Sanitize cfg.env

This commit is contained in:
Cadene
2024-02-25 12:02:29 +00:00
parent 9b469c4232
commit ed80db2846
6 changed files with 46 additions and 42 deletions

View File

@@ -88,7 +88,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
while step < cfg.train_steps:
is_offline = True
num_updates = cfg.episode_length
num_updates = cfg.env.episode_length
_step = step + num_updates
rollout_metrics = {}
@@ -98,11 +98,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
# TODO: use SyncDataCollector for that?
with torch.no_grad():
rollout = env.rollout(
max_steps=cfg.episode_length,
max_steps=cfg.env.episode_length,
policy=td_policy,
auto_cast_to_device=True,
)
assert len(rollout) <= cfg.episode_length
assert len(rollout) <= cfg.env.episode_length
rollout["episode"] = torch.tensor(
[online_episode_idx] * len(rollout), dtype=torch.int
)
@@ -133,7 +133,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
)
# Log training metrics
env_step = int(_step * cfg.action_repeat)
env_step = int(_step * cfg.env.action_repeat)
common_metrics = {
"episode": online_episode_idx,
"step": _step,