fix train.py, stats, eval.py (training is running)
This commit is contained in:
@@ -37,7 +37,6 @@ from pathlib import Path
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -47,8 +46,8 @@ from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
from lerobot.common.transforms import apply_inverse_transform
|
||||
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
@@ -92,9 +91,12 @@ def eval_policy(
|
||||
fps: int = 15,
|
||||
return_first_video: bool = False,
|
||||
transform: callable = None,
|
||||
seed=None,
|
||||
):
|
||||
if policy is not None:
|
||||
policy.eval()
|
||||
device = "cpu" if policy is None else next(policy.parameters()).device
|
||||
|
||||
start = time.time()
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
@@ -125,11 +127,11 @@ def eval_policy(
|
||||
policy.reset()
|
||||
else:
|
||||
logging.warning(
|
||||
f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
|
||||
f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
|
||||
)
|
||||
|
||||
# reset the environment
|
||||
observation, info = env.reset(seed=cfg.seed)
|
||||
observation, info = env.reset(seed=seed)
|
||||
maybe_render_frame(env)
|
||||
|
||||
rewards = []
|
||||
@@ -138,13 +140,12 @@ def eval_policy(
|
||||
|
||||
done = torch.tensor([False for _ in env.envs])
|
||||
step = 0
|
||||
do_rollout = True
|
||||
while do_rollout:
|
||||
while not done.all():
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation, transform)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
|
||||
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
@@ -180,10 +181,6 @@ def eval_policy(
|
||||
|
||||
step += 1
|
||||
|
||||
if done.all():
|
||||
do_rollout = False
|
||||
break
|
||||
|
||||
rewards = torch.stack(rewards, dim=1)
|
||||
successes = torch.stack(successes, dim=1)
|
||||
dones = torch.stack(dones, dim=1)
|
||||
@@ -295,6 +292,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
fps=cfg.env.fps,
|
||||
# TODO(rcadene): what should we do with the transform?
|
||||
transform=dataset.transform,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
print(info["aggregated"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user