From 14490148f32e3c98eed69baddd8774de9edff316 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Tue, 26 Nov 2024 11:58:29 +0000
Subject: [PATCH] added tdmpc2 to policy factory; shape fixes in tdmpc2

---
 lerobot/common/policies/factory.py | 7 +++++++
 .../common/policies/tdmpc2/modeling_tdmpc2.py | 17 ++++++-----------
 lerobot/common/policies/tdmpc2/tdmpc2_utils.py | 17 +++++++++--------
 lerobot/scripts/train.py | 12 ++++++++++++
 4 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 5cb2fd52..f75baec3 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
         from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
 
         return TDMPCPolicy, TDMPCConfig
+
+    elif name == "tdmpc2":
+        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
+        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
+
+        return TDMPC2Policy, TDMPC2Config
+
     elif name == "diffusion":
         from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
diff --git a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
index 3288aa70..ad6d1546 100644
--- a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
+++ b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
@@ -389,7 +389,7 @@ class TDMPC2Policy(
         reward_loss = (
             (
                 temporal_loss_coeffs
-                * soft_cross_entropy(reward_preds, reward, self.config)
+                * soft_cross_entropy(reward_preds, reward, self.config).mean(1)
                 * ~batch["next.reward_is_pad"]
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
             )
             .sum(0)
             .mean()
         )
+
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         ce_value_loss = 0.0
         for i in range(self.config.q_ensemble_size):
-            ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config)
+            ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config).mean(1)
 
         q_value_loss = (
             (
@@ -420,7 +421,6 @@ class TDMPC2Policy(
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
         # We won't need these gradients again so detach.
         z_preds = z_preds.detach()
-        self.model.change_q_grad(mode=False)
         action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
 
         with torch.no_grad():
@@ -430,14 +430,9 @@ class TDMPC2Policy(
             self.scale.update(qs[0])
             qs = self.scale(qs)
 
-        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
-            -1
-        )
-
         pi_loss = (
-            (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
-            * rho
-            # * temporal_loss_coeffs
+            (self.config.entropy_coef * log_pis - qs).mean(dim=2)
+            * temporal_loss_coeffs
             # `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0] * ~batch["action_is_pad"] @@ -447,7 +442,7 @@ class TDMPC2Policy( self.config.consistency_coeff * consistency_loss + self.config.reward_coeff * reward_loss + self.config.value_coeff * q_value_loss - + self.config.pi_coeff * pi_loss + + pi_loss ) info.update( diff --git a/lerobot/common/policies/tdmpc2/tdmpc2_utils.py b/lerobot/common/policies/tdmpc2/tdmpc2_utils.py index 22f1ca06..2806b71d 100644 --- a/lerobot/common/policies/tdmpc2/tdmpc2_utils.py +++ b/lerobot/common/policies/tdmpc2/tdmpc2_utils.py @@ -75,9 +75,6 @@ def soft_cross_entropy(pred, target, cfg): """Computes the cross entropy loss between predictions and soft targets.""" pred = F.log_softmax(pred, dim=-1) target = two_hot(target, cfg) - import pudb - - pudb.set_trace() return -(target * pred).sum(-1, keepdim=True) @@ -137,16 +134,20 @@ def symexp(x): def two_hot(x, cfg): """Converts a batch of scalars to soft two-hot encoded targets for discrete regression.""" + + # x shape [horizon, num_features] if cfg.num_bins == 0: return x elif cfg.num_bins == 1: return symlog(x) x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax) - bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() - bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float() - soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device) - soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset) - soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset) + bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() # shape [num_features] + bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) # shape [num_features , 1] + soft_two_hot = torch.zeros( + *x.shape, cfg.num_bins, device=x.device + ) # shape [horizon, num_features, num_bins] + soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset) + soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset) return soft_two_hot diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index f60f904e..9c7df689 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy): elif policy.name == "tdmpc": optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) lr_scheduler = None + + elif policy.name == "tdmpc2": + params_group = [ + {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale}, + {"params": policy.model._dynamics.parameters()}, + {"params": policy.model._reward.parameters()}, + {"params": policy.model._Qs.parameters()}, + {"params": policy.model._pi.parameters(), "eps": 1e-5}, + ] + optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr) + lr_scheduler = None + elif cfg.policy.name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler