forked from tangger/lerobot
offline training + online finetuning converge to 33 reward!
@@ -99,10 +99,12 @@ def train(cfg: dict):
             is_offline = False
 
             # TODO: use SyncDataCollector for that?
-            rollout = env.rollout(
-                max_steps=cfg.episode_length,
-                policy=td_policy,
-            )
+            with torch.no_grad():
+                rollout = env.rollout(
+                    max_steps=cfg.episode_length,
+                    policy=td_policy,
+                    auto_cast_to_device=True,
+                )
             assert len(rollout) <= cfg.episode_length
             rollout["episode"] = torch.tensor(
                 [online_episode_idx] * len(rollout), dtype=torch.int
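In the new version the whole environment interaction runs under `torch.no_grad()` (no autograd graph is needed while the policy only acts), the rollout is auto-cast between the env and policy devices, and each transition is tagged with its episode index. A minimal sketch of that collection step, assuming a TorchRL-style `env.rollout()` and a TensorDict replay buffer; the helper name `collect_online_episode` and the final `online_buffer.extend(...)` call are illustrative, not part of this hunk:

import torch


def collect_online_episode(env, td_policy, online_buffer, online_episode_idx, cfg):
    """Roll out one online episode and store it, tagged with its episode index."""
    # Acting only: skip autograd bookkeeping while the policy drives the env.
    with torch.no_grad():
        rollout = env.rollout(
            max_steps=cfg.episode_length,
            policy=td_policy,
            auto_cast_to_device=True,  # shuttle tensors between env and policy devices
        )
    # The env may terminate early, so the episode can be shorter than the cap.
    assert len(rollout) <= cfg.episode_length
    # Tag every transition with the episode it came from before it enters the buffer.
    rollout["episode"] = torch.tensor([online_episode_idx] * len(rollout), dtype=torch.int)
    online_buffer.extend(rollout)
    return rollout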
@@ -121,18 +123,14 @@ def train(cfg: dict):
             _step = min(step + len(rollout), cfg.train_steps)
 
         # Update model
-        train_metrics = {}
-        if is_offline:
-            for i in range(num_updates):
-                train_metrics.update(policy.update(offline_buffer, step + i))
-        else:
-            for i in range(num_updates):
-                train_metrics.update(
-                    policy.update(
-                        online_buffer,
-                        step + i // cfg.utd,
-                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
-                    )
-                )
+        for i in range(num_updates):
+            if is_offline:
+                train_metrics = policy.update(offline_buffer, step + i)
+            else:
+                train_metrics = policy.update(
+                    online_buffer,
+                    step + i // cfg.utd,
+                    demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                )
 
         # Log training metrics
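The second hunk collapses two nearly identical update loops into a single loop with the offline/online branch inside it, and replaces the accumulating `train_metrics.update(...)` calls with a plain assignment, so only the metrics of the last update are kept for logging. A sketch of the resulting control flow; the wrapper name `run_updates` and the guard initialisation of `train_metrics` are illustrative additions, not code from the commit:

def run_updates(policy, offline_buffer, online_buffer, num_updates, step, cfg, is_offline):
    """One loop, branch inside: mirrors the post-refactor update logic."""
    train_metrics = {}  # guard for num_updates == 0; the commit itself drops this dict
    for i in range(num_updates):
        if is_offline:
            # Pure offline phase: sample only from the demonstration buffer.
            train_metrics = policy.update(offline_buffer, step + i)
        else:
            # Online phase: several gradient updates share one env step when utd > 1,
            # optionally mixing in demonstrations for balanced sampling.
            train_metrics = policy.update(
                online_buffer,
                step + i // cfg.utd,
                demo_buffer=offline_buffer if cfg.balanced_sampling else None,
            )
    return train_metrics

Keeping the branch inside a single loop removes the duplicated `for` header and makes it harder for the offline and online code paths to drift apart.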