fix train.py, stats, eval.py (training is running)

This commit is contained in:
Cadene
2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

View File

@@ -145,7 +145,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# )
logging.info("make_env")
-env = make_env(cfg)
+env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
logging.info("make_policy")
policy = make_policy(cfg)
@@ -173,12 +173,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
eval_info, first_video = eval_policy(
env,
policy,
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,
transform=dataset.transform,
seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
if cfg.wandb.enable:
@@ -211,7 +210,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
-train_info = policy.update(batch, step)
+train_info = policy(batch, step)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
@@ -223,6 +222,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1
raise NotImplementedError()
demo_buffer = dataset if cfg.policy.balanced_sampling else None
online_step = 0
is_offline = False