diff --git a/lerobot/common/tdmpc.py b/lerobot/common/tdmpc.py
index d694a06df..03aeb7d70 100644
--- a/lerobot/common/tdmpc.py
+++ b/lerobot/common/tdmpc.py
@@ -129,19 +129,17 @@ class TDMPC(nn.Module):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
         if isinstance(obs, dict):
             obs = {
-                k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
+                k: o.detach().unsqueeze(0)
                 for k, o in obs.items()
             }
         else:
-            obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
-                0
-            )
+            obs = obs.detach().unsqueeze(0)
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
         else:
             a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
-        return a.cpu()
+        return a
 
     @torch.no_grad()
     def estimate_value(self, z, actions, horizon):
@@ -324,7 +322,7 @@ class TDMPC(nn.Module):
         # trajectory t = 256, horizon h = 5
         # (t h) ... -> h t ...
         batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-        batch = batch.to("cuda")
+        batch = batch.to(self.device)
 
         FIRST_FRAME = 0
         obs = {
@@ -469,7 +467,11 @@ class TDMPC(nn.Module):
         weighted_loss = (total_loss.squeeze(1) * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
-        weighted_loss.backward()
+        has_nan = torch.isnan(weighted_loss).item()
+        if has_nan:
+            print(f"weighted_loss has nan: {total_loss=} {weights=}")
+        else:
+            weighted_loss.backward()
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
@@ -479,12 +481,16 @@ class TDMPC(nn.Module):
         if self.cfg.per:
             # Update priorities
             priorities = priority_loss.clamp(max=1e4).detach()
-            replay_buffer.update_priority(
-                idxs[:num_slices],
-                priorities[:num_slices],
-            )
-            if demo_batch_size > 0:
-                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
+            has_nan = torch.isnan(priorities).any().item()
+            if has_nan:
+                print(f"priorities has nan: {priorities=}")
+            else:
+                replay_buffer.update_priority(
+                    idxs[:num_slices],
+                    priorities[:num_slices],
+                )
+                if demo_batch_size > 0:
+                    demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
 
         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 5268ebff4..e4c5b9b12 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -36,13 +36,15 @@ def eval_policy(
             # render first frame before rollout
             rendering_callback(env)
-        rollout = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            callback=rendering_callback if save_video else None,
-            auto_reset=False,
-            tensordict=tensordict,
-        )
+        with torch.inference_mode():
+            rollout = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                callback=rendering_callback if save_video else None,
+                auto_reset=False,
+                tensordict=tensordict,
+                auto_cast_to_device=True,
+            )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
         ep_reward = rollout["next", "reward"].sum()
         ep_success = rollout["next", "success"].any()
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 13020b554..2d33d7b0e 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -99,10 +99,12 @@ def train(cfg: dict):
         is_offline = False
         # TODO: use SyncDataCollector for that?
-        rollout = env.rollout(
-            max_steps=cfg.episode_length,
-            policy=td_policy,
-        )
+        with torch.no_grad():
+            rollout = env.rollout(
+                max_steps=cfg.episode_length,
+                policy=td_policy,
+                auto_cast_to_device=True,
+            )
         assert len(rollout) <= cfg.episode_length
         rollout["episode"] = torch.tensor(
             [online_episode_idx] * len(rollout), dtype=torch.int
         )
@@ -121,18 +123,14 @@ def train(cfg: dict):
         _step = min(step + len(rollout), cfg.train_steps)
 
         # Update model
-        train_metrics = {}
-        if is_offline:
-            for i in range(num_updates):
-                train_metrics.update(policy.update(offline_buffer, step + i))
-        else:
-            for i in range(num_updates):
-                train_metrics.update(
-                    policy.update(
-                        online_buffer,
-                        step + i // cfg.utd,
-                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
-                    )
+        for i in range(num_updates):
+            if is_offline:
+                train_metrics = policy.update(offline_buffer, step + i)
+            else:
+                train_metrics = policy.update(
+                    online_buffer,
+                    step + i // cfg.utd,
+                    demo_buffer=offline_buffer if cfg.balanced_sampling else None,
                 )
 
         # Log training metrics
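Side note on the non-finite guard introduced in the tdmpc.py hunks above: the pattern is to check the scalar loss (and, for PER, the sampled priorities) for NaNs and skip the backward pass or the priority update when the check fails, so a single bad batch cannot poison the optimizer state or the replay priorities. Below is a minimal, self-contained sketch of that guard; the model, optimizer, and loss are placeholders, not the lerobot/TD-MPC objects from the diff.

import torch

# Placeholder model and optimizer for illustration only.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def training_step(batch: torch.Tensor, target: torch.Tensor) -> None:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(batch), target)

    # Only backpropagate when the loss is finite; a NaN loss would spread NaNs
    # into every parameter through the gradients and corrupt later updates.
    if torch.isnan(loss).item():
        print(f"loss has nan, skipping update: {loss=}")
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=10.0, error_if_nonfinite=False
        )
        optimizer.step()

training_step(torch.randn(8, 4), torch.randn(8, 1))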