offline training + online finetuning converge to 33 reward!

This commit is contained in:
Cadene
2024-02-18 01:23:44 +00:00
parent 0b4084f0f8
commit a5c305a7a4
3 changed files with 42 additions and 36 deletions

View File

@@ -129,19 +129,17 @@ class TDMPC(nn.Module):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
         if isinstance(obs, dict):
             obs = {
-                k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
+                k: o.detach().unsqueeze(0)
                 for k, o in obs.items()
             }
         else:
-            obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
-                0
-            )
+            obs = obs.detach().unsqueeze(0)
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
         else:
             a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
-        return a.cpu()
+        return a
 
     @torch.no_grad()
     def estimate_value(self, z, actions, horizon):
@@ -324,7 +322,7 @@ class TDMPC(nn.Module):
         # trajectory t = 256, horizon h = 5
         # (t h) ... -> h t ...
         batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-        batch = batch.to("cuda")
+        batch = batch.to(self.device)
 
         FIRST_FRAME = 0
         obs = {
@@ -469,7 +467,11 @@ class TDMPC(nn.Module):
         weighted_loss = (total_loss.squeeze(1) * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
-        weighted_loss.backward()
+        has_nan = torch.isnan(weighted_loss).item()
+        if has_nan:
+            print(f"weighted_loss has nan: {total_loss=} {weights=}")
+        else:
+            weighted_loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
@@ -479,12 +481,16 @@ class TDMPC(nn.Module):
         if self.cfg.per:
             # Update priorities
             priorities = priority_loss.clamp(max=1e4).detach()
-            replay_buffer.update_priority(
-                idxs[:num_slices],
-                priorities[:num_slices],
-            )
-            if demo_batch_size > 0:
-                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
+            has_nan = torch.isnan(priorities).any().item()
+            if has_nan:
+                print(f"priorities has nan: {priorities=}")
+            else:
+                replay_buffer.update_priority(
+                    idxs[:num_slices],
+                    priorities[:num_slices],
+                )
+                if demo_batch_size > 0:
+                    demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
 
         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)

View File

@@ -36,13 +36,15 @@ def eval_policy(
         # render first frame before rollout
         rendering_callback(env)
 
-        rollout = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            callback=rendering_callback if save_video else None,
-            auto_reset=False,
-            tensordict=tensordict,
-        )
+        with torch.inference_mode():
+            rollout = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                callback=rendering_callback if save_video else None,
+                auto_reset=False,
+                tensordict=tensordict,
+                auto_cast_to_device=True,
+            )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
         ep_reward = rollout["next", "reward"].sum()
         ep_success = rollout["next", "success"].any()

View File

@@ -99,10 +99,12 @@ def train(cfg: dict):
             is_offline = False
             # TODO: use SyncDataCollector for that?
-            rollout = env.rollout(
-                max_steps=cfg.episode_length,
-                policy=td_policy,
-            )
+            with torch.no_grad():
+                rollout = env.rollout(
+                    max_steps=cfg.episode_length,
+                    policy=td_policy,
+                    auto_cast_to_device=True,
+                )
 
             assert len(rollout) <= cfg.episode_length
             rollout["episode"] = torch.tensor(
                 [online_episode_idx] * len(rollout), dtype=torch.int
@@ -121,18 +123,14 @@ def train(cfg: dict):
         _step = min(step + len(rollout), cfg.train_steps)
 
         # Update model
-        train_metrics = {}
-        if is_offline:
-            for i in range(num_updates):
-                train_metrics.update(policy.update(offline_buffer, step + i))
-        else:
-            for i in range(num_updates):
-                train_metrics.update(
-                    policy.update(
-                        online_buffer,
-                        step + i // cfg.utd,
-                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
-                    )
-                )
+        for i in range(num_updates):
+            if is_offline:
+                train_metrics = policy.update(offline_buffer, step + i)
+            else:
+                train_metrics = policy.update(
+                    online_buffer,
+                    step + i // cfg.utd,
+                    demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                )
 
         # Log training metrics