Refactor TD-MPC (#103)

Co-authored-by: Cadene <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Alexander Soare
2024-05-01 16:40:04 +01:00
committed by GitHub
parent a4891095e4
commit d1855a202a
17 changed files with 1105 additions and 1205 deletions

View File

@@ -67,10 +67,10 @@ def eval_policy(
"""
set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
"""
policy.eval()
fps = env.unwrapped.metadata["render_fps"]
if policy is not None:
policy.eval()
device = "cpu" if policy is None else next(policy.parameters()).device
start = time.time()
@@ -132,7 +132,7 @@ def eval_policy(
# get the next action for the environment
with torch.inference_mode():
action = policy.select_action(observation, step=step)
action = policy.select_action(observation)
# convert to cpu numpy
action = postprocess_action(action)
@@ -386,6 +386,7 @@ def eval(
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy.eval()
info = eval_policy(
env,