offline training + online finetuning converge to 33 reward!
@@ -129,19 +129,17 @@ class TDMPC(nn.Module):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
         if isinstance(obs, dict):
             obs = {
-                k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
+                k: o.detach().unsqueeze(0)
                 for k, o in obs.items()
             }
         else:
-            obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
-                0
-            )
+            obs = obs.detach().unsqueeze(0)
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
         else:
             a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
-        return a.cpu()
+        return a

     @torch.no_grad()
     def estimate_value(self, z, actions, horizon):
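Note: act() no longer builds tensors itself, so it now relies on observations already being float tensors on the policy's device (the rollouts below pass auto_cast_to_device=True), and the returned action stays on that device instead of being moved to CPU. A minimal sketch of what a caller holding raw numpy observations would have to do before calling act(); the helper name is illustrative, not part of the commit:

import torch

def to_policy_tensors(obs, device):
    # Convert a raw observation (dict of arrays or a single array) into
    # float32 tensors on the policy's device, as the new act() expects.
    if isinstance(obs, dict):
        return {k: torch.as_tensor(o, dtype=torch.float32, device=device) for k, o in obs.items()}
    return torch.as_tensor(obs, dtype=torch.float32, device=device)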
@@ -324,7 +322,7 @@ class TDMPC(nn.Module):
         # trajectory t = 256, horizon h = 5
         # (t h) ... -> h t ...
         batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-        batch = batch.to("cuda")
+        batch = batch.to(self.device)

         FIRST_FRAME = 0
         obs = {
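Note: replacing the hard-coded "cuda" with self.device lets the same update code run on CPU or whichever accelerator is configured. A small sketch of how such a device attribute is typically derived from the config (an assumption for illustration; the commit does not show where self.device is set):

import torch

def resolve_device(requested: str) -> torch.device:
    # Fall back to CPU when CUDA is requested but unavailable.
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(requested)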
@@ -469,7 +467,11 @@ class TDMPC(nn.Module):

         weighted_loss = (total_loss.squeeze(1) * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
-        weighted_loss.backward()
+        has_nan = torch.isnan(weighted_loss).item()
+        if has_nan:
+            print(f"weighted_loss has nan: {total_loss=} {weights=}")
+        else:
+            weighted_loss.backward()

         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
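Note: backward() is now skipped when the loss is NaN, so a single bad batch cannot poison the optimizer state with NaN gradients. The same guard written as a reusable helper with a slightly broader finiteness check (a sketch, not part of the commit):

import torch

def backward_if_finite(loss: torch.Tensor) -> bool:
    # Only run the backward pass when the scalar loss is finite (no NaN/Inf).
    if not torch.isfinite(loss).item():
        print(f"skipping backward, non-finite loss: {loss=}")
        return False
    loss.backward()
    return True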
@@ -479,12 +481,16 @@ class TDMPC(nn.Module):
         if self.cfg.per:
             # Update priorities
             priorities = priority_loss.clamp(max=1e4).detach()
-            replay_buffer.update_priority(
-                idxs[:num_slices],
-                priorities[:num_slices],
-            )
-            if demo_batch_size > 0:
-                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
+            has_nan = torch.isnan(priorities).any().item()
+            if has_nan:
+                print(f"priorities has nan: {priorities=}")
+            else:
+                replay_buffer.update_priority(
+                    idxs[:num_slices],
+                    priorities[:num_slices],
+                )
+                if demo_batch_size > 0:
+                    demo_buffer.update_priority(demo_idxs, priorities[num_slices:])

         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
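Note: the PER branch gets the same treatment, skipping the priority write when any priority is NaN so the sampler never stores NaN weights. A hedged alternative (not what this commit does) would be to sanitize the priorities and still update, for example:

# Replace NaN/Inf priorities with small finite values instead of skipping the update.
priorities = torch.nan_to_num(priority_loss.clamp(max=1e4), nan=1e-6, posinf=1e4).detach()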
@@ -36,13 +36,15 @@ def eval_policy(
         # render first frame before rollout
         rendering_callback(env)

-        rollout = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            callback=rendering_callback if save_video else None,
-            auto_reset=False,
-            tensordict=tensordict,
-        )
+        with torch.inference_mode():
+            rollout = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                callback=rendering_callback if save_video else None,
+                auto_reset=False,
+                tensordict=tensordict,
+                auto_cast_to_device=True,
+            )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
         ep_reward = rollout["next", "reward"].sum()
         ep_success = rollout["next", "success"].any()
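Note: wrapping the evaluation rollout in torch.inference_mode() goes further than torch.no_grad(): autograd is fully disabled and the produced tensors are marked as inference tensors, which makes pure evaluation slightly cheaper but means they cannot later participate in autograd. A minimal, generic illustration of the pattern (not tied to the torchrl API; names are illustrative):

import torch

@torch.inference_mode()
def greedy_action(policy_net: torch.nn.Module, obs: torch.Tensor) -> torch.Tensor:
    # No graph is recorded; outputs are inference tensors, fine for evaluation only.
    return policy_net(obs).argmax(dim=-1)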
@@ -99,10 +99,12 @@ def train(cfg: dict):
         is_offline = False

         # TODO: use SyncDataCollector for that?
-        rollout = env.rollout(
-            max_steps=cfg.episode_length,
-            policy=td_policy,
-        )
+        with torch.no_grad():
+            rollout = env.rollout(
+                max_steps=cfg.episode_length,
+                policy=td_policy,
+                auto_cast_to_device=True,
+            )
         assert len(rollout) <= cfg.episode_length
         rollout["episode"] = torch.tensor(
             [online_episode_idx] * len(rollout), dtype=torch.int
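Note: the online collection rollout is wrapped in torch.no_grad() rather than inference_mode(), presumably because the collected tensors are pushed into the replay buffer and reused inside later autograd code, which inference tensors would not allow. The same pattern as a standalone helper (name is illustrative, not from the commit):

import torch

def collect_episode(env, policy, max_steps: int):
    # No graph is needed for data collection, but the returned tensors stay
    # usable in later autograd code (unlike under inference_mode).
    with torch.no_grad():
        return env.rollout(max_steps=max_steps, policy=policy, auto_cast_to_device=True)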
@@ -121,18 +123,14 @@ def train(cfg: dict):
         _step = min(step + len(rollout), cfg.train_steps)

         # Update model
-        train_metrics = {}
-        if is_offline:
-            for i in range(num_updates):
-                train_metrics.update(policy.update(offline_buffer, step + i))
-        else:
-            for i in range(num_updates):
-                train_metrics.update(
-                    policy.update(
-                        online_buffer,
-                        step + i // cfg.utd,
-                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
-                    )
-                )
+        for i in range(num_updates):
+            if is_offline:
+                train_metrics = policy.update(offline_buffer, step + i)
+            else:
+                train_metrics = policy.update(
+                    online_buffer,
+                    step + i // cfg.utd,
+                    demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                )

         # Log training metrics
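Note: the rewritten loop decides offline vs. online per update instead of branching once, and it keeps only the last update's metrics in train_metrics (the old form merged them with dict.update). The i // cfg.utd term makes the logical step passed to policy.update() advance once per cfg.utd gradient updates (update-to-data ratio). A tiny illustration of that bookkeeping, assuming cfg.utd = 4 and num_updates = 8:

utd, num_updates, step = 4, 8, 100
print([step + i // utd for i in range(num_updates)])
# -> [100, 100, 100, 100, 101, 101, 101, 101]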