Merge remote-tracking branch 'origin/user/rcadene/2024_03_31_remove_torchrl' into user/rcadene/2024_03_31_remove_torchrl

This commit is contained in:
Cadene
2024-04-10 11:34:51 +00:00
19 changed files with 1082 additions and 1805 deletions

View File

@@ -307,7 +307,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
dataset = make_dataset(cfg, stats_path=stats_path)
transform = make_dataset(cfg, stats_path=stats_path).transform
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@@ -322,7 +322,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
video_dir=Path(out_dir) / "eval",
fps=cfg.env.fps,
# TODO(rcadene): what should we do with the transform?
transform=dataset.transform,
transform=transform,
seed=cfg.seed,
)
print(info["aggregated"])

View File

@@ -41,7 +41,6 @@ def log_train_info(logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
data_s = info["data_s"]
update_s = info["update_s"]
# A sample is an (observation,action) pair, where observation and action
@@ -62,7 +61,6 @@ def log_train_info(logger, info, step, cfg, dataset, is_offline):
f"grdn:{grad_norm:.3f}",
f"lr:{lr:0.1e}",
# in seconds
f"data_s:{data_s:.3f}",
f"updt_s:{update_s:.3f}",
]
logging.info(" ".join(log_items))