diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 90e7ecc..371ab22 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -16,8 +16,8 @@ def make_policy(cfg):
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
             cfg_ema=cfg.ema,
-            n_obs_steps=cfg.n_obs_steps,
-            n_action_steps=cfg.n_action_steps,
+            # n_obs_steps=cfg.n_obs_steps,
+            # n_action_steps=cfg.n_action_steps,
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py
index 04aa5b1..2d547f2 100644
--- a/lerobot/common/policies/tdmpc/policy.py
+++ b/lerobot/common/policies/tdmpc/policy.py
@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
         # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
-        self.batch_size = cfg.batch_size
 
         self.register_buffer("step", torch.zeros(1))
 
@@ -325,7 +324,7 @@ class TDMPCPolicy(nn.Module):
     def _td_target(self, next_z, reward, mask):
         """Compute the TD-target from a reward and the observation at the following time step."""
         next_v = self.model.V(next_z)
-        td_target = reward + self.cfg.discount * mask * next_v
+        td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
         return td_target
 
     def forward(self, batch, step):
@@ -420,6 +419,8 @@ class TDMPCPolicy(nn.Module):
         # idxs = torch.cat([idxs, demo_idxs])
         # weights = torch.cat([weights, demo_weights])
 
+        batch_size = batch["index"].shape[0]
+
         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
         # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
         # batch size b = 256, time/horizon t = 5
@@ -433,7 +434,7 @@ class TDMPCPolicy(nn.Module):
         # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
         done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
         mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
 
         obses = {
             "rgb": batch["observation.image"],
@@ -476,7 +477,7 @@ class TDMPCPolicy(nn.Module):
             td_targets = self._td_target(next_z, reward, mask)
 
         # Latent rollout
-        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
+        zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
         reward_preds = torch.empty_like(reward, device=self.device)
         assert reward.shape[0] == horizon
         z = self.model.encode(obs)
@@ -485,22 +486,21 @@ class TDMPCPolicy(nn.Module):
         for t in range(horizon):
             z, reward_pred = self.model.next(z, action[t])
             zs[t + 1] = z
-            reward_preds[t] = reward_pred
+            reward_preds[t] = reward_pred.squeeze(1)
 
         with torch.no_grad():
             v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
 
         # Predictions
         qs = self.model.Q(zs[:-1], action, return_type="all")
+        qs = qs.squeeze(3)
         value_info["Q"] = qs.mean().item()
         v = self.model.V(zs[:-1])
         value_info["V"] = v.mean().item()
 
         # Losses
-        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
-        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
-            dim=0
-        )
+        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
+        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
         reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
         q_value_loss, priority_loss = 0, 0
         for q in range(self.cfg.num_q):
@@ -508,7 +508,9 @@ class TDMPCPolicy(nn.Module):
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
 
         expectile = h.linear_schedule(self.cfg.expectile, step)
-        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
+        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
+            dim=0
+        )
 
         total_loss = (
             self.cfg.consistency_coef * consistency_loss
@@ -517,7 +519,7 @@ class TDMPCPolicy(nn.Module):
             + self.cfg.value_coef * v_value_loss
         )
 
-        weighted_loss = (total_loss.squeeze(1) * weights).mean()
+        weighted_loss = (total_loss * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
         has_nan = torch.isnan(weighted_loss).item()
         if has_nan:
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 6da62e1..811ee82 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -38,6 +38,7 @@ policy:
 
   horizon: ${horizon}
   n_obs_steps: ${n_obs_steps}
+  n_action_steps: ${n_action_steps}
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index 4fd2b6b..2ebaad9 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -36,7 +36,6 @@ policy:
   log_std_max: 2
 
   # learning
-  batch_size: 256
   max_buffer_size: 10000
   horizon: 5
   reward_coef: 0.5
diff --git a/poetry.lock b/poetry.lock
index 98449df..e5105b7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -897,7 +897,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-aloha.git"
 reference = "HEAD"
-resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
 
 [[package]]
 name = "gym-pusht"
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 82033b7..5d0c0d8 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
-        batch_size=cfg.policy.batch_size,
+        batch_size=2,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
         drop_last=True,
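
Not part of the patch: below is a minimal, self-contained sketch of the tensor-shape bookkeeping that the squeeze()/view() changes in lerobot/common/policies/tdmpc/policy.py rely on, with reward and loss terms kept at (horizon, batch). The sizes and the toy value head are illustrative assumptions; the real TDMPCPolicy heads and helpers (h.mse, h.l2_expectile) are not reproduced here.

import torch

horizon, batch_size, latent_dim = 5, 8, 16

# Toy stand-ins: a latent rollout and a value head that, like nn.Linear(latent_dim, 1),
# returns a trailing singleton dimension.
next_z = torch.randn(horizon, batch_size, latent_dim)
value_head = torch.nn.Linear(latent_dim, 1)

reward = torch.randn(horizon, batch_size)   # (T, B)
mask = torch.ones_like(reward)              # (T, B)

next_v = value_head(next_z)                 # (T, B, 1)
# (T, B) and (T, B, 1) tensors do not combine elementwise as intended, hence the
# squeeze on the value-head output before forming the TD-target.
td_target = reward + 0.99 * mask * next_v.squeeze(2)       # (T, B)

# With per-step losses kept at (T, B), rho only needs a (T, 1) view to broadcast,
# and summing over time yields a per-sample loss of shape (B,).
rho = torch.pow(0.5, torch.arange(horizon)).view(-1, 1)    # (T, 1)
total_loss = (rho * (td_target - reward) ** 2).sum(dim=0)  # (B,)

# Per-sample weights are (B,) as well, so no squeeze is needed before weighting.
weights = torch.ones(batch_size)
weighted_loss = (total_loss * weights).mean()
print(td_target.shape, total_loss.shape, weighted_loss.shape)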