forked from tangger/lerobot
Refactor eval.py (#127)
This commit is contained in:
@@ -1,13 +1,17 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||
"""
|
||||
Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
|
||||
returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
"""
|
||||
if n_envs is not None and n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
kwargs = {
|
||||
"obs_type": "pixels_agent_pos",
|
||||
"render_mode": "rgb_array",
|
||||
@@ -28,16 +32,13 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
|
||||
if num_parallel_envs == 0:
|
||||
# non-batched version of the env that returns an observation of shape (c)
|
||||
env = gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
else:
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env = gym.vector.SyncVectorEnv(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
for _ in range(num_parallel_envs)
|
||||
]
|
||||
)
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def preprocess_observation(observation):
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
obs = {}
|
||||
return_observations = {}
|
||||
|
||||
if isinstance(observation["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observation["pixels"]}
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
@@ -26,17 +34,10 @@ def preprocess_observation(observation):
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
obs[imgkey] = img
|
||||
return_observations[imgkey] = img
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
||||
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def postprocess_action(action):
|
||||
action = action.to("cpu").numpy()
|
||||
assert (
|
||||
action.ndim == 2
|
||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
||||
return action
|
||||
return return_observations
|
||||
|
||||
@@ -115,7 +115,7 @@ class Logger:
|
||||
for k, v in d.items():
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video, step, mode="train"):
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
wandb_video = self._wandb.Video(video, fps=self._cfg.fps, format="mp4")
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
Reference in New Issue
Block a user