Sanitize cfg.env
This commit is contained in:
@@ -88,7 +88,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
while step < cfg.train_steps:
|
||||
is_offline = True
|
||||
num_updates = cfg.episode_length
|
||||
num_updates = cfg.env.episode_length
|
||||
_step = step + num_updates
|
||||
rollout_metrics = {}
|
||||
|
||||
@@ -98,11 +98,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
# TODO: use SyncDataCollector for that?
|
||||
with torch.no_grad():
|
||||
rollout = env.rollout(
|
||||
max_steps=cfg.episode_length,
|
||||
max_steps=cfg.env.episode_length,
|
||||
policy=td_policy,
|
||||
auto_cast_to_device=True,
|
||||
)
|
||||
assert len(rollout) <= cfg.episode_length
|
||||
assert len(rollout) <= cfg.env.episode_length
|
||||
rollout["episode"] = torch.tensor(
|
||||
[online_episode_idx] * len(rollout), dtype=torch.int
|
||||
)
|
||||
@@ -133,7 +133,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
)
|
||||
|
||||
# Log training metrics
|
||||
env_step = int(_step * cfg.action_repeat)
|
||||
env_step = int(_step * cfg.env.action_repeat)
|
||||
common_metrics = {
|
||||
"episode": online_episode_idx,
|
||||
"step": _step,
|
||||
|
||||
Reference in New Issue
Block a user