offline training + online finetuning converge to 33 reward!

Cadene
2024-02-18 01:23:44 +00:00
parent 0b4084f0f8
commit a5c305a7a4
3 changed files with 42 additions and 36 deletions


@@ -99,10 +99,12 @@ def train(cfg: dict):
             is_offline = False
             # TODO: use SyncDataCollector for that?
-            rollout = env.rollout(
-                max_steps=cfg.episode_length,
-                policy=td_policy,
-            )
+            with torch.no_grad():
+                rollout = env.rollout(
+                    max_steps=cfg.episode_length,
+                    policy=td_policy,
+                    auto_cast_to_device=True,
+                )
             assert len(rollout) <= cfg.episode_length
             rollout["episode"] = torch.tensor(
                 [online_episode_idx] * len(rollout), dtype=torch.int
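For context: the hunk above wraps the environment rollout in torch.no_grad() (pure data collection, no autograd graph) and tags every transition with its episode index before it goes into the online buffer. Below is a minimal sketch of that collection step, assuming a TorchRL-style env whose rollout() returns a TensorDict, as in the diff; the collect_online_rollout helper itself is hypothetical.

import torch

def collect_online_rollout(env, td_policy, episode_length, online_episode_idx):
    # Run the policy in the environment without tracking gradients; this is
    # inference-only data collection, so no autograd graph is needed.
    with torch.no_grad():
        rollout = env.rollout(
            max_steps=episode_length,
            policy=td_policy,
            auto_cast_to_device=True,  # shuttle tensors between env and policy devices
        )
    # Episodes can terminate early, so the rollout may be shorter than episode_length.
    assert len(rollout) <= episode_length
    # Tag each transition with the episode it came from, for bookkeeping in the buffer.
    rollout["episode"] = torch.tensor([online_episode_idx] * len(rollout), dtype=torch.int)
    return rollout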
@@ -121,18 +123,14 @@ def train(cfg: dict):
             _step = min(step + len(rollout), cfg.train_steps)
         # Update model
         train_metrics = {}
-        if is_offline:
-            for i in range(num_updates):
-                train_metrics.update(policy.update(offline_buffer, step + i))
-        else:
-            for i in range(num_updates):
-                train_metrics.update(
-                    policy.update(
-                        online_buffer,
-                        step + i // cfg.utd,
-                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
-                    )
-                )
+        for i in range(num_updates):
+            if is_offline:
+                train_metrics = policy.update(offline_buffer, step + i)
+            else:
+                train_metrics = policy.update(
+                    online_buffer,
+                    step + i // cfg.utd,
+                    demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                )
         # Log training metrics
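For context: the hunk above collapses the two per-phase loops into a single loop that branches per update. Below is a minimal sketch of the resulting control flow, reusing the names from the diff (policy.update, offline_buffer, online_buffer, num_updates, cfg.utd, cfg.balanced_sampling); the run_updates wrapper itself is hypothetical.

def run_updates(policy, cfg, step, num_updates, is_offline, offline_buffer, online_buffer):
    # One loop serves both phases; only the buffer sampled and the step index differ.
    train_metrics = {}
    for i in range(num_updates):
        if is_offline:
            # Offline phase: every update samples from the demonstration buffer.
            train_metrics = policy.update(offline_buffer, step + i)
        else:
            # Online phase: sample from the online buffer, optionally mixing in
            # demonstrations when balanced sampling is enabled. The step index
            # advances once every cfg.utd updates (update-to-data ratio).
            train_metrics = policy.update(
                online_buffer,
                step + i // cfg.utd,
                demo_buffer=offline_buffer if cfg.balanced_sampling else None,
            )
    return train_metrics

Note that the new code rebinds train_metrics on every iteration, so the metrics logged afterwards come from the final update of the loop, whereas the old code merged each update's metrics into one dict via dict.update().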