forked from tangger/lerobot
fix train.py, stats, eval.py (training is running)
This commit is contained in:
@@ -145,7 +145,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
# )
|
||||
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg)
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg)
|
||||
@@ -173,12 +173,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
eval_info, first_video = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
num_episodes=cfg.eval_episodes,
|
||||
max_steps=cfg.env.episode_length,
|
||||
return_first_video=True,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
save_video=True,
|
||||
transform=dataset.transform,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
|
||||
if cfg.wandb.enable:
|
||||
@@ -211,7 +210,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = policy.update(batch, step)
|
||||
train_info = policy(batch, step)
|
||||
|
||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||
if step % cfg.log_freq == 0:
|
||||
@@ -223,6 +222,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
step += 1
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
demo_buffer = dataset if cfg.policy.balanced_sampling else None
|
||||
online_step = 0
|
||||
is_offline = False
|
||||
|
||||
Reference in New Issue
Block a user