offline training + online finetuning converge to 33 reward!

Cadene
2024-02-18 01:23:44 +00:00
parent 0b4084f0f8
commit a5c305a7a4
3 changed files with 42 additions and 36 deletions


@@ -129,19 +129,17 @@ class TDMPC(nn.Module):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict):
obs = {
k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
k: o.detach().unsqueeze(0)
for k, o in obs.items()
}
else:
obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
0
)
obs = obs.detach().unsqueeze(0)
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a.cpu()
return a
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
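
After this change, act() assumes its observations are already torch tensors living on the policy's device and returns the action without copying it to the CPU, so device handling moves to the caller. A minimal sketch of that caller-side preparation; to_policy_obs and the numpy observation are illustrative stand-ins, not code from this repo:

import numpy as np
import torch

def to_policy_obs(obs_np: np.ndarray, device: torch.device) -> torch.Tensor:
    # Hypothetical helper: convert the raw numpy observation once, up front,
    # so act() only has to detach/unsqueeze what it receives.
    return torch.as_tensor(obs_np, dtype=torch.float32, device=device)

# obs = to_policy_obs(env_obs, device=policy.device)
# action = policy.act(obs, t0=True, step=0)   # action stays on the policy device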
@@ -324,7 +322,7 @@ class TDMPC(nn.Module):
         # trajectory t = 256, horizon h = 5
         # (t h) ... -> h t ...
         batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-        batch = batch.to("cuda")
+        batch = batch.to(self.device)

         FIRST_FRAME = 0
         obs = {
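
The "(t h) ... -> h t ..." comment describes folding a flat batch of num_slices * horizon transitions into a horizon-major layout before moving it to the module's device. A small illustration with a plain tensor standing in for the sampled batch (the batch above appears to be a TensorDict, so its reshape acts on batch dimensions); the shapes and feature size here are made up:

import torch

num_slices, horizon, feat = 256, 5, 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Flat batch of num_slices * horizon transitions, e.g. freshly sampled from a buffer.
batch = torch.randn(num_slices * horizon, feat)

# (t h) ... -> h t ...: group by trajectory slice, then put the horizon axis first.
batch = batch.reshape(num_slices, horizon, feat).transpose(1, 0).contiguous()
batch = batch.to(device)  # follow the configured device instead of hard-coding "cuda"
print(batch.shape)  # torch.Size([5, 256, 4])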
@@ -469,7 +467,11 @@ class TDMPC(nn.Module):
         weighted_loss = (total_loss.squeeze(1) * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
-        weighted_loss.backward()
+        has_nan = torch.isnan(weighted_loss).item()
+        if has_nan:
+            print(f"weighted_loss has nan: {total_loss=} {weights=}")
+        else:
+            weighted_loss.backward()

         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
         )
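
Checking the loss for NaN before calling backward() keeps one bad batch from writing NaN into every gradient; the update is simply skipped and the offending values are printed. A self-contained sketch of the same guard with a placeholder model and loss, not the TDMPC objective:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()

if torch.isnan(loss).item():
    # Skip the step entirely rather than propagate NaN gradients.
    print(f"loss has nan: {loss=}")
else:
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0, error_if_nonfinite=False)
    optimizer.step()
optimizer.zero_grad()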
@@ -479,12 +481,16 @@ class TDMPC(nn.Module):
         if self.cfg.per:
             # Update priorities
             priorities = priority_loss.clamp(max=1e4).detach()
-            replay_buffer.update_priority(
-                idxs[:num_slices],
-                priorities[:num_slices],
-            )
-            if demo_batch_size > 0:
-                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
+            has_nan = torch.isnan(priorities).any().item()
+            if has_nan:
+                print(f"priorities has nan: {priorities=}")
+            else:
+                replay_buffer.update_priority(
+                    idxs[:num_slices],
+                    priorities[:num_slices],
+                )
+                if demo_batch_size > 0:
+                    demo_buffer.update_priority(demo_idxs, priorities[num_slices:])

         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
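
The priorities written back to the prioritized buffers get the same treatment: if any value is NaN the write-back is skipped, otherwise the first num_slices entries go to the online replay buffer and the remainder to the demonstration buffer. A sketch of that split as a standalone helper; the buffer objects are duck-typed stand-ins, and update_priority(indices, priorities) is assumed as their interface rather than taken from a specific library:

import torch

def write_back_priorities(replay_buffer, idxs, priorities, num_slices,
                          demo_buffer=None, demo_idxs=None):
    # Hypothetical helper mirroring the guarded update above.
    priorities = priorities.clamp(max=1e4).detach()
    if torch.isnan(priorities).any().item():
        print(f"priorities has nan: {priorities=}")
        return
    replay_buffer.update_priority(idxs[:num_slices], priorities[:num_slices])
    if demo_buffer is not None:
        demo_buffer.update_priority(demo_idxs, priorities[num_slices:])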


@@ -36,13 +36,15 @@ def eval_policy(
             # render first frame before rollout
             rendering_callback(env)

-        rollout = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            callback=rendering_callback if save_video else None,
-            auto_reset=False,
-            tensordict=tensordict,
-        )
+        with torch.inference_mode():
+            rollout = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                callback=rendering_callback if save_video else None,
+                auto_reset=False,
+                tensordict=tensordict,
+                auto_cast_to_device=True,
+            )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
         ep_reward = rollout["next", "reward"].sum()
         ep_success = rollout["next", "success"].any()
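
Wrapping the evaluation rollout in torch.inference_mode() guarantees that no autograd graph is built while the policy acts in the environment; unlike no_grad, tensors produced inside can never re-enter autograd, which is fine for a rollout whose outputs are only summed and logged. A toy illustration with a stand-in linear policy rather than the TDMPC model:

import torch

policy_net = torch.nn.Linear(4, 2)
obs = torch.randn(1, 4)

with torch.inference_mode():
    action = policy_net(obs)  # no graph is recorded for this forward pass

print(action.requires_grad)  # False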


@@ -99,10 +99,12 @@ def train(cfg: dict):
             is_offline = False

             # TODO: use SyncDataCollector for that?
-            rollout = env.rollout(
-                max_steps=cfg.episode_length,
-                policy=td_policy,
-            )
+            with torch.no_grad():
+                rollout = env.rollout(
+                    max_steps=cfg.episode_length,
+                    policy=td_policy,
+                    auto_cast_to_device=True,
+                )
             assert len(rollout) <= cfg.episode_length
             rollout["episode"] = torch.tensor(
                 [online_episode_idx] * len(rollout), dtype=torch.int
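
Online data collection likewise runs under torch.no_grad(), since the rollout is only stored in the buffer and tagged with the episode it came from, never backpropagated through. A small sketch with stand-in tensors in place of the real environment and td_policy:

import torch

policy_net = torch.nn.Linear(4, 2)   # stand-in for td_policy
online_episode_idx = 7

with torch.no_grad():
    observations = torch.randn(25, 4)      # stand-in for observations from env.rollout
    actions = policy_net(observations)     # actions are computed without a graph

# Tag every collected transition with its episode index, as rollout["episode"] does above.
episode_ids = torch.tensor([online_episode_idx] * len(observations), dtype=torch.int)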
@@ -121,18 +123,14 @@ def train(cfg: dict):
         _step = min(step + len(rollout), cfg.train_steps)

         # Update model
-        train_metrics = {}
-        if is_offline:
-            for i in range(num_updates):
-                train_metrics.update(policy.update(offline_buffer, step + i))
-        else:
-            for i in range(num_updates):
-                train_metrics.update(
-                    policy.update(
-                        online_buffer,
-                        step + i // cfg.utd,
-                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
-                    )
-                )
+        for i in range(num_updates):
+            if is_offline:
+                train_metrics = policy.update(offline_buffer, step + i)
+            else:
+                train_metrics = policy.update(
+                    online_buffer,
+                    step + i // cfg.utd,
+                    demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                )

         # Log training metrics
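
The update loop now iterates once over num_updates and branches on is_offline inside, so the offline and online paths share the same loop; in this form train_metrics holds the metrics returned by the last update of the loop. A runnable sketch of the same control flow with toy stand-ins for the policy and buffers:

def fake_update(buffer, step, demo_buffer=None):
    # Stand-in for policy.update: returns a metrics dict instead of training anything.
    return {"step": step, "buffer_size": len(buffer), "uses_demo": demo_buffer is not None}

offline_buffer, online_buffer = list(range(1000)), list(range(200))
num_updates, step, utd, is_offline, balanced_sampling = 4, 100, 2, False, True

for i in range(num_updates):
    if is_offline:
        train_metrics = fake_update(offline_buffer, step + i)
    else:
        train_metrics = fake_update(
            online_buffer,
            step + i // utd,
            demo_buffer=offline_buffer if balanced_sampling else None,
        )
print(train_metrics)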