fix train.py, stats, eval.py (training is running)

This commit is contained in:
Cadene
2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

View File

@@ -37,7 +37,6 @@ from pathlib import Path
import einops
import gymnasium as gym
import hydra
import imageio
import numpy as np
import torch
@@ -47,8 +46,8 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.transforms import apply_inverse_transform
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
def write_video(video_path, stacked_frames, fps):
@@ -92,9 +91,12 @@ def eval_policy(
fps: int = 15,
return_first_video: bool = False,
transform: callable = None,
seed=None,
):
if policy is not None:
policy.eval()
device = "cpu" if policy is None else next(policy.parameters()).device
start = time.time()
sum_rewards = []
max_rewards = []
@@ -125,11 +127,11 @@ def eval_policy(
policy.reset()
else:
logging.warning(
f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
)
# reset the environment
observation, info = env.reset(seed=cfg.seed)
observation, info = env.reset(seed=seed)
maybe_render_frame(env)
rewards = []
@@ -138,13 +140,12 @@ def eval_policy(
done = torch.tensor([False for _ in env.envs])
step = 0
do_rollout = True
while do_rollout:
while not done.all():
# apply transform to normalize the observations
observation = preprocess_observation(observation, transform)
# send observation to device/gpu
observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
# get the next action for the environment
with torch.inference_mode():
@@ -180,10 +181,6 @@ def eval_policy(
step += 1
if done.all():
do_rollout = False
break
rewards = torch.stack(rewards, dim=1)
successes = torch.stack(successes, dim=1)
dones = torch.stack(dones, dim=1)
@@ -295,6 +292,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
fps=cfg.env.fps,
# TODO(rcadene): what should we do with the transform?
transform=dataset.transform,
seed=cfg.seed,
)
print(info["aggregated"])