forked from tangger/lerobot
Refactor TD-MPC (#103)
Co-authored-by: Cadene <re.cadene@gmail.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -67,10 +67,10 @@ def eval_policy(
|
||||
"""
|
||||
set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
|
||||
"""
|
||||
policy.eval()
|
||||
|
||||
fps = env.unwrapped.metadata["render_fps"]
|
||||
|
||||
if policy is not None:
|
||||
policy.eval()
|
||||
device = "cpu" if policy is None else next(policy.parameters()).device
|
||||
|
||||
start = time.time()
|
||||
@@ -132,7 +132,7 @@ def eval_policy(
|
||||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=step)
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# convert to cpu numpy
|
||||
action = postprocess_action(action)
|
||||
@@ -386,6 +386,7 @@ def eval(
|
||||
else:
|
||||
# Note: We need the dataset stats to pass to the policy's normalization modules.
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||
policy.eval()
|
||||
|
||||
info = eval_policy(
|
||||
env,
|
||||
|
||||
Reference in New Issue
Block a user