From 7bf36cd4139ec28bf6f3df55f5561efe89ef646c Mon Sep 17 00:00:00 2001
From: Cadene
Date: Sun, 10 Mar 2024 22:00:48 +0000
Subject: [PATCH] Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in
 env, Minor modifications, (TODO: Refactor Pusht and Simxarm)

---
 lerobot/common/datasets/aloha.py        |  3 -
 lerobot/common/envs/abstract.py         | 75 +++++++++++++++++++++++++
 lerobot/common/envs/aloha/env.py        | 46 ++++++---------
 lerobot/common/policies/act/detr_vae.py |  8 ++-
 lerobot/common/policies/act/policy.py   | 20 ++++---
 lerobot/configs/env/aloha.yaml          |  2 +-
 lerobot/configs/policy/act.yaml         |  7 ++-
 lerobot/scripts/eval.py                 | 15 ++---
 lerobot/scripts/train.py                |  3 +
 sbatch.sh                               |  3 +-
 tests/test_datasets.py                  |  8 +--
 11 files changed, 131 insertions(+), 59 deletions(-)

diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 95a334500..851cc75b2 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def image_keys(self) -> list:
         return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
 
-    # def _is_downloaded(self) -> bool:
-    #     return False
-
     def _download_and_preproc(self):
         raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
         if not raw_dir.is_dir():
diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py
index e69de29bb..2901e4d2b 100644
--- a/lerobot/common/envs/abstract.py
+++ b/lerobot/common/envs/abstract.py
@@ -0,0 +1,75 @@
+import abc
+from collections import deque
+from typing import Optional
+
+from tensordict import TensorDict
+from torchrl.envs import EnvBase
+
+
+class AbstractEnv(EnvBase):
+    def __init__(
+        self,
+        task,
+        frame_skip: int = 1,
+        from_pixels: bool = False,
+        pixels_only: bool = False,
+        image_size=None,
+        seed=1337,
+        device="cpu",
+        num_prev_obs=1,
+        num_prev_action=0,
+    ):
+        super().__init__(device=device, batch_size=[])
+        self.task = task
+        self.frame_skip = frame_skip
+        self.from_pixels = from_pixels
+        self.pixels_only = pixels_only
+        self.image_size = image_size
+        self.num_prev_obs = num_prev_obs
+        self.num_prev_action = num_prev_action
+        self._rendering_hooks = []
+
+        if pixels_only:
+            assert from_pixels
+        if from_pixels:
+            assert image_size
+
+        self._make_spec()
+        self._current_seed = self.set_seed(seed)
+
+        if self.num_prev_obs > 0:
+            self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
+            self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
+        if self.num_prev_action > 0:
+            raise NotImplementedError()
+            # self._prev_action_queue = deque(maxlen=self.num_prev_action)
+
+    def register_rendering_hook(self, func):
+        self._rendering_hooks.append(func)
+
+    def call_rendering_hooks(self):
+        for func in self._rendering_hooks:
+            func(self)
+
+    def reset_rendering_hooks(self):
+        self._rendering_hooks = []
+
+    @abc.abstractmethod
+    def render(self, mode="rgb_array", width=640, height=480):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _reset(self, tensordict: Optional[TensorDict] = None):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _step(self, tensordict: TensorDict):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _make_spec(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _set_seed(self, seed: Optional[int]):
+        raise NotImplementedError()
diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py
index acb30b325..f0cbb25df 100644
--- a/lerobot/common/envs/aloha/env.py
+++ b/lerobot/common/envs/aloha/env.py
@@ -15,8 +15,8 @@ from torchrl.data.tensor_specs import (
     DiscreteTensorSpec,
     UnboundedContinuousTensorSpec,
 )
-from torchrl.envs import EnvBase
+from lerobot.common.envs.abstract import AbstractEnv
 
 from lerobot.common.envs.aloha.constants import (
     ACTIONS,
     ASSETS_DIR,
@@ -28,14 +28,13 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
     InsertionEndEffectorTask,
     TransferCubeEndEffectorTask,
 )
+from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
 from lerobot.common.utils import set_seed
 
-from .utils import sample_box_pose, sample_insertion_pose
-
 _has_gym = importlib.util.find_spec("gym") is not None
 
 
-class AlohaEnv(EnvBase):
+class AlohaEnv(AbstractEnv):
     def __init__(
         self,
         task,
@@ -48,20 +47,17 @@ class AlohaEnv(EnvBase):
         num_prev_obs=1,
         num_prev_action=0,
     ):
-        super().__init__(device=device, batch_size=[])
-        self.task = task
-        self.frame_skip = frame_skip
-        self.from_pixels = from_pixels
-        self.pixels_only = pixels_only
-        self.image_size = image_size
-        self.num_prev_obs = num_prev_obs
-        self.num_prev_action = num_prev_action
-
-        if pixels_only:
-            assert from_pixels
-        if from_pixels:
-            assert image_size
-
+        super().__init__(
+            task=task,
+            frame_skip=frame_skip,
+            from_pixels=from_pixels,
+            pixels_only=pixels_only,
+            image_size=image_size,
+            seed=seed,
+            device=device,
+            num_prev_obs=num_prev_obs,
+            num_prev_action=num_prev_action,
+        )
         if not _has_gym:
             raise ImportError("Cannot import gym.")
 
@@ -70,16 +66,6 @@ class AlohaEnv(EnvBase):
 
         self._env = self._make_env_task(task)
 
-        self._make_spec()
-        self._current_seed = self.set_seed(seed)
-
-        if self.num_prev_obs > 0:
-            self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
-            self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
-        if self.num_prev_action > 0:
-            raise NotImplementedError()
-            # self._prev_action_queue = deque(maxlen=self.num_prev_action)
-
     def render(self, mode="rgb_array", width=640, height=480):
         # TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
         image = self._env.physics.render(height=height, width=width, camera_id="top")
@@ -172,6 +158,8 @@ class AlohaEnv(EnvBase):
             )
         else:
             raise NotImplementedError()
+
+        self.call_rendering_hooks()
         return td
 
     def _step(self, tensordict: TensorDict):
@@ -207,6 +195,8 @@
                 stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
             obs = stacked_obs
 
+        self.call_rendering_hooks()
+
         td = TensorDict(
             {
                 "observation": TensorDict(obs, batch_size=[]),
diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py
index 2c9704308..0f2626f74 100644
--- a/lerobot/common/policies/act/detr_vae.py
+++ b/lerobot/common/policies/act/detr_vae.py
@@ -27,7 +27,9 @@ def get_sinusoid_encoding_table(n_position, d_hid):
 class DETRVAE(nn.Module):
     """This is the DETR module that performs object detection"""
 
-    def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names):
+    def __init__(
+        self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
+    ):
         """Initializes the model.
         Parameters:
             backbones: torch module of the backbone to be used. See backbone.py
@@ -42,6 +44,7 @@ class DETRVAE(nn.Module):
         self.camera_names = camera_names
         self.transformer = transformer
         self.encoder = encoder
+        self.vae = vae
         hidden_dim = transformer.d_model
         self.action_head = nn.Linear(hidden_dim, action_dim)
         self.is_pad_head = nn.Linear(hidden_dim, 1)
@@ -86,7 +89,7 @@ class DETRVAE(nn.Module):
         is_training = actions is not None  # train or val
         bs, _ = qpos.shape
         ### Obtain latent z from action sequence
-        if is_training:
+        if self.vae and is_training:
             # project action sequence to embedding dim, and concat with a CLS token
             action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
             qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
@@ -200,6 +203,7 @@ def build(args):
         action_dim=args.action_dim,
         num_queries=args.num_queries,
         camera_names=args.camera_names,
+        vae=args.vae,
     )
 
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index 13f4c199e..7928b3ab1 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -11,7 +11,6 @@ from lerobot.common.policies.act.detr_vae import build
 
 def build_act_model_and_optimizer(cfg):
     model = build(cfg)
-    model.cuda()
 
     param_dicts = [
         {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
@@ -51,6 +50,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
         self.kl_weight = self.cfg.kl_weight
         logging.info(f"KL Weight {self.kl_weight}")
 
+        self.to(self.device)
+
     def update(self, replay_buffer, step):
         del step
 
@@ -192,20 +193,25 @@ class ActionChunkingTransformerPolicy(nn.Module):
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         image = normalize(image)
 
-        is_train_mode = actions is not None
-        if is_train_mode:  # training time
+        is_training = actions is not None
+        if is_training:  # training time
             actions = actions[:, : self.model.num_queries]
             if is_pad is not None:
                 is_pad = is_pad[:, : self.model.num_queries]
 
             a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
-            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
-            loss_dict = {}
+
             all_l1 = F.l1_loss(actions, a_hat, reduction="none")
             l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
+
+            loss_dict = {}
             loss_dict["l1"] = l1
-            loss_dict["kl"] = total_kld[0]
-            loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
+            if self.cfg.vae:
+                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
+                loss_dict["kl"] = total_kld[0]
+                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
+            else:
+                loss_dict["loss"] = loss_dict["l1"]
             return loss_dict
         else:
             action, _, (_, _) = self.model(qpos, image, env_state)  # no action, sample from prior
diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml
index ceb8e87fe..df464c753 100644
--- a/lerobot/configs/env/aloha.yaml
+++ b/lerobot/configs/env/aloha.yaml
@@ -17,7 +17,7 @@ env:
   pixels_only: False
   image_size: [3, 480, 640]
   action_repeat: 1
-  episode_length: 300
+  episode_length: 400
   fps: ${fps}
 
 policy:
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index 358ed83cf..a52c3f541 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -29,9 +29,10 @@ policy:
   kl_weight: 10
   hidden_dim: 512
   dim_feedforward: 3200
-  enc_layers: 7
-  dec_layers: 8
+  enc_layers: 4
+  dec_layers: 7
   nheads: 8
+  #camera_names: [top, front_close, left_pillar, right_pillar]
   camera_names: [top]
   position_embedding: sine
   masks: false
@@ -39,6 +40,8 @@ policy:
   dropout: 0.1
   pre_norm: false
 
+  vae: true
+
   batch_size: 8
   per_alpha: 0.6
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index c9338dca6..8d0b2e880 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -38,27 +38,18 @@ def eval_policy(
     successes = []
     threads = []
     for i in tqdm.tqdm(range(num_episodes)):
-        tensordict = env.reset()
-
         ep_frames = []
-        if save_video or (return_first_video and i == 0):
 
-            def rendering_callback(env, td=None):
+        def render_frame(env):
             ep_frames.append(env.render())  # noqa: B023
 
-            # render first frame before rollout
-            rendering_callback(env)
-        else:
-            rendering_callback = None
+        env.register_rendering_hook(render_frame)
 
         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
-                callback=rendering_callback,
-                auto_reset=False,
-                tensordict=tensordict,
                 auto_cast_to_device=True,
             )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
@@ -85,6 +76,8 @@
         if return_first_video and i == 0:
             first_video = stacked_frames.transpose(0, 3, 1, 2)
 
+        env.reset_rendering_hooks()
+
     for thread in threads:
         thread.join()
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index be3bef8b1..02b1efae5 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -1,4 +1,5 @@
 import logging
+from pathlib import Path
 
 import hydra
 import numpy as np
@@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
             num_episodes=cfg.eval_episodes,
             max_steps=cfg.env.episode_length // cfg.n_action_steps,
             return_first_video=True,
+            video_dir=Path(out_dir) / "eval",
+            save_video=True,
         )
         log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
         if cfg.wandb.enable:
diff --git a/sbatch.sh b/sbatch.sh
index da52c472a..cb5b285af 100644
--- a/sbatch.sh
+++ b/sbatch.sh
@@ -17,6 +17,7 @@ apptainer exec --nv \
 ~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
 
 source ~/.bashrc
-conda activate fowm
+#conda activate fowm
+conda activate lerobot
 
 srun $CMD
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e63ae2c1e..71c14951b 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -12,10 +12,10 @@ from .utils import init_config
         # ("simxarm", "lift"),
         ("pusht", "pusht"),
         # TODO(aliberts): add aloha when dataset is available on hub
-        # ("aloha", "sim_insertion_human"),
-        # ("aloha", "sim_insertion_scripted"),
-        # ("aloha", "sim_transfer_cube_human"),
-        # ("aloha", "sim_transfer_cube_scripted"),
+        ("aloha", "sim_insertion_human"),
+        ("aloha", "sim_insertion_scripted"),
+        ("aloha", "sim_transfer_cube_human"),
+        ("aloha", "sim_transfer_cube_scripted"),
     ],
 )
 def test_factory(env_name, dataset_id):
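
Usage sketch: the rendering hooks introduced by AbstractEnv replace the torchrl
rollout callback that eval.py previously relied on. Below is a minimal sketch of
the intended flow, in the spirit of the eval.py changes above; the task name and
rollout length are illustrative assumptions, not values taken from this patch:

    import torch

    from lerobot.common.envs.aloha.env import AlohaEnv

    # AbstractEnv.__init__ builds the specs, seeds the env, sets up the
    # observation queues, and starts with an empty list of rendering hooks.
    env = AlohaEnv(task="sim_transfer_cube", from_pixels=True, image_size=[3, 480, 640])

    # Hooks are invoked as func(env) from _reset and _step, so one frame is
    # captured per environment transition.
    ep_frames = []
    env.register_rendering_hook(lambda e: ep_frames.append(e.render()))

    with torch.inference_mode():
        env.rollout(max_steps=10, auto_cast_to_device=True)

    # Hooks persist until explicitly cleared; eval.py resets them once the
    # episode's frames have been consumed.
    env.reset_rendering_hooks()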
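
Side note on the new vae flag: with policy.vae set to false, DETRVAE skips the
CVAE encoder at training time and the policy drops the KL term, leaving a plain
L1 objective. A condensed sketch of the gated loss follows; act_loss is a
hypothetical standalone helper, and the closed-form Gaussian KL below stands in
for the repo's kl_divergence helper (assumed to return the total KL first):

    import torch
    import torch.nn.functional as F

    def act_loss(actions, a_hat, mu, logvar, is_pad=None, vae=True, kl_weight=10.0):
        # Masked L1 between ground-truth and predicted action chunks.
        all_l1 = F.l1_loss(actions, a_hat, reduction="none")
        l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()

        loss_dict = {"l1": l1}
        if vae:
            # KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent
            # dimensions and averaged over the batch.
            kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).sum(1).mean()
            loss_dict["kl"] = kl
            loss_dict["loss"] = l1 + kl_weight * kl
        else:
            loss_dict["loss"] = l1
        return loss_dict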