From 265b0ec44d42ea8cb912534df5b1e818d6caec0d Mon Sep 17 00:00:00 2001 From: Remi Date: Thu, 30 May 2024 13:45:22 +0200 Subject: [PATCH] Refactor env to add key word arguments from config yaml (#223) Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- .../push_dataset_to_hub/aloha_dora_format.py | 18 +++++++----- lerobot/common/envs/factory.py | 14 ++++----- lerobot/configs/default.yaml | 2 ++ lerobot/configs/env/aloha.yaml | 10 +++---- lerobot/configs/env/pusht.yaml | 11 ++++--- lerobot/configs/env/xarm.yaml | 11 ++++--- lerobot/scripts/train.py | 29 +++++++++++++++---- 7 files changed, 59 insertions(+), 36 deletions(-) diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py index d1e5a52..4a21bc2 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py @@ -69,12 +69,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): # "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by # matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest". # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps. - # This is not a problem when the tolerance is set to be low enough to avoid matching timestamps that # are too far appart. direction="nearest", tolerance=pd.Timedelta(f"{1/fps} seconds"), ) - # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes df = df[df["episode_index"] != -1] @@ -89,9 +87,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): raise ValueError(path) episode_index = int(match.group(1)) episode_index_per_cam[key] = episode_index - assert ( - len(set(episode_index_per_cam.values())) == 1 - ), f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}" + if len(set(episode_index_per_cam.values())) != 1: + raise ValueError( + f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}" + ) return episode_index df["episode_index"] = df.apply(get_episode_index, axis=1) @@ -119,7 +118,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): # sanity check episode indices go from 0 to n-1 ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")] expected_ep_ids = list(range(df["episode_index"].max() + 1)) - assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}" + if ep_ids != expected_ep_ids: + raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}") # Create symlink to raw videos directory (that needs to be absolute not relative) out_dir.mkdir(parents=True, exist_ok=True) @@ -132,7 +132,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): continue for ep_idx in ep_ids: video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4" - assert video_path.exists(), f"Video file not found in {video_path}" + if not video_path.exists(): + raise ValueError(f"Video file not found in {video_path}") data_dict = {} for key in df: @@ -144,7 +145,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): # sanity check the video path is well formated video_path = videos_dir.parent / data_dict[key][0]["path"] - assert video_path.exists(), f"Video file not found in {video_path}" + if not video_path.exists(): + raise ValueError(f"Video file not found in {video_path}") # is number elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1: data_dict[key] = torch.from_numpy(df[key].values) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 83f94cf..d73939b 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv if n_envs is not None and n_envs < 1: raise ValueError("`n_envs must be at least 1") - kwargs = { - "obs_type": "pixels_agent_pos", - "render_mode": "rgb_array", - "max_episode_steps": cfg.env.episode_length, - "visualization_width": 384, - "visualization_height": 384, - } - package_name = f"gym_{cfg.env.name}" try: @@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv raise e gym_handle = f"{package_name}/{cfg.env.task}" + gym_kwgs = dict(cfg.env.get("gym", {})) + + if cfg.env.get("episode_length"): + gym_kwgs["max_episode_steps"] = cfg.env.episode_length # batched version of the env that returns an observation of shape (b, c) env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv env = env_cls( [ - lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs) + lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs) for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size) ] ) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 9ae3078..f223876 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -37,6 +37,8 @@ training: save_freq: ??? log_freq: 250 save_checkpoint: true + num_workers: 4 + batch_size: ??? eval: n_episodes: 1 diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 95e4503..296a448 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -5,10 +5,10 @@ fps: 50 env: name: aloha task: AlohaInsertion-v0 - from_pixels: True - pixels_only: False - image_size: [3, 480, 640] - episode_length: 400 - fps: ${fps} state_dim: 14 action_dim: 14 + fps: ${fps} + episode_length: 400 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index 43e9d18..771fbbf 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -5,10 +5,13 @@ fps: 10 env: name: pusht task: PushT-v0 - from_pixels: True - pixels_only: False image_size: 96 - episode_length: 300 - fps: ${fps} state_dim: 2 action_dim: 2 + fps: ${fps} + episode_length: 300 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array + visualization_width: 384 + visualization_height: 384 diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 098b039..9dbb96f 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -5,10 +5,13 @@ fps: 15 env: name: xarm task: XarmLift-v0 - from_pixels: True - pixels_only: False image_size: 84 - episode_length: 25 - fps: ${fps} state_dim: 4 action_dim: 4 + fps: ${fps} + episode_length: 25 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array + visualization_width: 384 + visualization_height: 384 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5fb86f3..eb33b26 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -281,8 +281,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("make_dataset") offline_dataset = make_dataset(cfg) - logging.info("make_env") - eval_env = make_env(cfg) + # Create environment used for evaluating checkpoints during training on simulation data. + # On real-world data, no need to create an environment as evaluations are done outside train.py, + # using the eval.py instead, with gym_dora environment and dora-rs. + if cfg.training.eval_freq > 0: + logging.info("make_env") + eval_env = make_env(cfg) logging.info("make_policy") policy = make_policy( @@ -315,7 +319,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Note: this helper will be used in offline and online training loops. def evaluate_and_checkpoint_if_needed(step): - if step % cfg.training.eval_freq == 0: + if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): eval_info = eval_policy( @@ -349,7 +353,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # create dataloader for offline training dataloader = torch.utils.data.DataLoader( offline_dataset, - num_workers=4, + num_workers=cfg.training.num_workers, batch_size=cfg.training.batch_size, shuffle=True, pin_memory=device.type != "cpu", @@ -386,6 +390,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No step += 1 + logging.info("End of offline training") + + if cfg.training.online_steps == 0: + if cfg.training.eval_freq > 0: + eval_env.close() + return + + # create an env dedicated to online episodes collection from policy rollout + online_training_env = make_env(cfg, n_envs=1) + # create an empty online dataset similar to offline dataset online_dataset = deepcopy(offline_dataset) online_dataset.hf_dataset = {} @@ -406,8 +420,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No drop_last=False, ) - eval_env.close() - logging.info("End of training") + logging.info("End of online training") + + if cfg.training.eval_freq > 0: + eval_env.close() + online_training_env.close() @hydra.main(version_base="1.2", config_name="default", config_path="../configs")