offline training + online finetuning converge to 33 reward!
@@ -36,13 +36,15 @@ def eval_policy(
         # render first frame before rollout
         rendering_callback(env)

-        rollout = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            callback=rendering_callback if save_video else None,
-            auto_reset=False,
-            tensordict=tensordict,
-        )
+        with torch.inference_mode():
+            rollout = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                callback=rendering_callback if save_video else None,
+                auto_reset=False,
+                tensordict=tensordict,
+                auto_cast_to_device=True,
+            )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
         ep_reward = rollout["next", "reward"].sum()
         ep_success = rollout["next", "success"].any()
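
For context on the first hunk: torch.inference_mode() disables autograd tracking for everything computed inside the block, so the evaluation rollout neither records a graph nor pays its memory cost, and auto_cast_to_device=True asks torchrl's env.rollout to move data between the environment device and the policy device around each policy call. A minimal standalone sketch of the inference-mode part (illustrative only, not code from this commit):

import torch

# Stand-in policy; the diff evaluates the actual trained policy instead.
policy = torch.nn.Linear(4, 2)
obs = torch.randn(1, 4)

# Inside inference_mode, no autograd graph is recorded.
with torch.inference_mode():
    action = policy(obs)

print(action.requires_grad)  # False: the rollout cannot be backpropagated through
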
@@ -99,10 +99,12 @@ def train(cfg: dict):
         is_offline = False

         # TODO: use SyncDataCollector for that?
-        rollout = env.rollout(
-            max_steps=cfg.episode_length,
-            policy=td_policy,
-        )
+        with torch.no_grad():
+            rollout = env.rollout(
+                max_steps=cfg.episode_length,
+                policy=td_policy,
+                auto_cast_to_device=True,
+            )
         assert len(rollout) <= cfg.episode_length
         rollout["episode"] = torch.tensor(
             [online_episode_idx] * len(rollout), dtype=torch.int
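
The second hunk applies the same treatment to online data collection in train(): the rollout that only fills the replay buffer is wrapped in torch.no_grad() and also gets auto_cast_to_device=True. The context lines that follow tag every collected transition with the current episode index; a self-contained sketch of that tagging (with made-up sizes, not the project's real buffer) could look like:

import torch
from tensordict import TensorDict

online_episode_idx = 3

# A fake 5-step rollout carrying the kind of fields the training loop reads.
rollout = TensorDict(
    {"reward": torch.zeros(5, 1), "success": torch.zeros(5, 1, dtype=torch.bool)},
    batch_size=[5],
)

# Tag every transition with the episode it came from, mirroring train().
rollout["episode"] = torch.tensor([online_episode_idx] * len(rollout), dtype=torch.int)
print(rollout["episode"])  # tensor([3, 3, 3, 3, 3], dtype=torch.int32)
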
@@ -121,18 +123,14 @@ def train(cfg: dict):
         _step = min(step + len(rollout), cfg.train_steps)

         # Update model
-        train_metrics = {}
-        if is_offline:
-            for i in range(num_updates):
-                train_metrics.update(policy.update(offline_buffer, step + i))
-        else:
-            for i in range(num_updates):
-                train_metrics.update(
-                    policy.update(
-                        online_buffer,
-                        step + i // cfg.utd,
-                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
-                    )
-                )
+        for i in range(num_updates):
+            if is_offline:
+                train_metrics = policy.update(offline_buffer, step + i)
+            else:
+                train_metrics = policy.update(
+                    online_buffer,
+                    step + i // cfg.utd,
+                    demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                )

         # Log training metrics
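
The last hunk replaces the offline/online branching around two separate update loops with a single for i in range(num_updates) loop that branches inside; train_metrics now holds the metrics of the latest update rather than a merged dict. In the online branch, the step passed to policy.update advances as step + i // cfg.utd, i.e. once per utd gradient updates (the update-to-data ratio). A toy illustration of that indexing, with invented numbers:

# Toy numbers; cfg.utd and num_updates come from the training config in the diff.
step = 100
utd = 2
num_updates = 4

for i in range(num_updates):
    # The effective step only advances every `utd` gradient updates.
    print(step + i // utd)  # prints 100, 100, 101, 101
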