fix train.py, stats, eval.py (training is running)

Cadene
2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

eval.py

@@ -37,7 +37,6 @@ from pathlib import Path
 import einops
 import gymnasium as gym
 import hydra
-import imageio
 import numpy as np
 import torch
@@ -47,8 +46,8 @@ from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 from lerobot.common.transforms import apply_inverse_transform
+from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed


 def write_video(video_path, stacked_frames, fps):
@@ -92,9 +91,12 @@ def eval_policy(
     fps: int = 15,
     return_first_video: bool = False,
     transform: callable = None,
+    seed=None,
 ):
-    policy.eval()
+    if policy is not None:
+        policy.eval()
+    device = "cpu" if policy is None else next(policy.parameters()).device
     start = time.time()
     sum_rewards = []
     max_rewards = []
@@ -125,11 +127,11 @@
         policy.reset()
     else:
         logging.warning(
-            f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
+            f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
         )

     # reset the environment
-    observation, info = env.reset(seed=cfg.seed)
+    observation, info = env.reset(seed=seed)
     maybe_render_frame(env)

     rewards = []
@@ -138,13 +140,12 @@
     done = torch.tensor([False for _ in env.envs])
     step = 0
-    do_rollout = True
-    while do_rollout:
+    while not done.all():
         # apply transform to normalize the observations
         observation = preprocess_observation(observation, transform)

         # send observation to device/gpu
-        observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
+        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

         # get the next action for the environment
         with torch.inference_mode():
@@ -180,10 +181,6 @@
         step += 1

-        if done.all():
-            do_rollout = False
-            break

     rewards = torch.stack(rewards, dim=1)
     successes = torch.stack(successes, dim=1)
     dones = torch.stack(dones, dim=1)
@@ -295,6 +292,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         fps=cfg.env.fps,
         # TODO(rcadene): what should we do with the transform?
         transform=dataset.transform,
+        seed=cfg.seed,
     )

     print(info["aggregated"])
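The rollout refactor above (`while not done.all():`, an explicit `seed` threaded into `env.reset`, device taken from the policy) follows a standard batched-evaluation pattern. A minimal runnable sketch of that pattern with gymnasium's vector API; the env choice, env count, and MAX_STEPS bound are illustrative assumptions, not lerobot code:

import gymnasium as gym
import torch

# a small batch of identical envs; lerobot sizes this via num_parallel_envs
envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
observation, info = envs.reset(seed=42)  # explicit seed, as in env.reset(seed=seed)

done = torch.tensor([False] * envs.num_envs)
step, MAX_STEPS = 0, 500  # assumed safety bound on rollout length
while not done.all() and step < MAX_STEPS:
    action = envs.action_space.sample()  # stand-in for the policy's action selection
    observation, reward, terminated, truncated, info = envs.step(action)
    # an env counts as done once it terminates or is truncated, and stays done
    done = done | torch.tensor(terminated) | torch.tensor(truncated)
    step += 1

Looping until `done.all()` is what lets the commit delete both the `do_rollout` flag and the trailing `if done.all(): break` block.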

train.py

@@ -145,7 +145,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # )

     logging.info("make_env")
-    env = make_env(cfg)
+    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

     logging.info("make_policy")
     policy = make_policy(cfg)
@@ -173,12 +173,11 @@
             eval_info, first_video = eval_policy(
                 env,
                 policy,
-                num_episodes=cfg.eval_episodes,
-                max_steps=cfg.env.episode_length,
                 return_first_video=True,
                 video_dir=Path(out_dir) / "eval",
                 save_video=True,
                 transform=dataset.transform,
+                seed=cfg.seed,
             )
             log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
             if cfg.wandb.enable:
@@ -211,7 +210,7 @@
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)

-        train_info = policy.update(batch, step)
+        train_info = policy(batch, step)

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
@@ -223,6 +222,8 @@
         step += 1
     raise NotImplementedError()
     demo_buffer = dataset if cfg.policy.balanced_sampling else None
+    online_step = 0
+    is_offline = False
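A subtle change above is `train_info = policy(batch, step)`: the policy is now invoked through `nn.Module.__call__`, so the parameter update runs inside `forward` and the separate `policy.update` method goes away. A self-contained sketch of that pattern; `ToyPolicy`, its layer, and the loss are invented for illustration and are not the lerobot policy API:

import torch
from torch import nn

class ToyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, batch: dict, step: int) -> dict:
        # forward() performs one optimization step and returns logging info,
        # mirroring the train_info = policy(batch, step) call in the diff
        loss = nn.functional.mse_loss(self.net(batch["observation"]), batch["action"])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"step": step, "loss": loss.item()}

policy = ToyPolicy()
batch = {"observation": torch.randn(8, 4), "action": torch.randn(8, 2)}
train_info = policy(batch, 0)  # __call__ dispatches to forward()

Routing the update through `__call__` keeps hooks and `.train()`/`.eval()` mode handling consistent with any other `nn.Module`, which is presumably the motivation for changing the call site.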