From d585c73f9f73ed3313f9e12969e60038389e6f59 Mon Sep 17 00:00:00 2001
From: Remi
Date: Fri, 31 May 2024 15:31:02 +0200
Subject: [PATCH] Add real-world support for ACT on Aloha/Aloha2 (#228)

Co-authored-by: Alexander Soare
---
 .gitignore                                    |   1 -
 lerobot/__init__.py                           |  24 ++
 lerobot/common/datasets/factory.py            |  18 +-
 .../common/policies/act/configuration_act.py  |  25 ++-
 lerobot/common/policies/act/modeling_act.py   |  61 +++--
 .../diffusion/configuration_diffusion.py      |  23 +-
 .../policies/tdmpc/configuration_tdmpc.py     |   9 +
 lerobot/configs/env/dora_aloha_real.yaml      |  13 ++
 lerobot/configs/policy/act_real.yaml          | 115 ++++++++++
 lerobot/configs/policy/act_real_no_state.yaml | 111 ++++++++++
 poetry.lock                                   | 209 +++++++++++-------
 pyproject.toml                                |   2 +
 .../actions.safetensors                       |   3 +
 .../grad_stats.safetensors                    |   3 +
 .../output_dict.safetensors                   |   3 +
 .../param_stats.safetensors                   |   3 +
 .../actions.safetensors                       |   3 +
 .../grad_stats.safetensors                    |   3 +
 .../output_dict.safetensors                   |   3 +
 .../param_stats.safetensors                   |   3 +
 ...ensor.py => save_policy_to_safetensors.py} |  19 +-
 tests/test_policies.py                        |  11 +-
 22 files changed, 525 insertions(+), 140 deletions(-)
 create mode 100644 lerobot/configs/env/dora_aloha_real.yaml
 create mode 100644 lerobot/configs/policy/act_real.yaml
 create mode 100644 lerobot/configs/policy/act_real_no_state.yaml
 create mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/param_stats.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/actions.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/grad_stats.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/output_dict.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/param_stats.safetensors
 rename tests/scripts/{save_policy_to_safetensor.py => save_policy_to_safetensors.py} (91%)

diff --git a/.gitignore b/.gitignore
index 5b73b9ad..4ccf404d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -121,7 +121,6 @@ celerybeat.pid
 # Environments
 .env
 .venv
-env/
 venv/
 ENV/
 env.bak/
diff --git a/lerobot/__init__.py b/lerobot/__init__.py
index e0234f29..a5a90fb4 100644
--- a/lerobot/__init__.py
+++ b/lerobot/__init__.py
@@ -45,6 +45,9 @@ import itertools
 
 from lerobot.__version__ import __version__  # noqa: F401
 
+# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
+# refers to a yaml file AND a modeling name. Same for `available_envs`, which refers to
+# a yaml file AND an environment name. The difference should be more obvious.
 available_tasks_per_env = {
     "aloha": [
         "AlohaInsertion-v0",
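For illustration (not part of the patch): a minimal sketch of how the registry lists extended in these hunks are typically consumed. `available_envs`, `available_policies_per_env`, and `env_task_pairs` are the module-level names defined in `lerobot/__init__.py`; the printed values follow from the additions in this diff.

```python
# Illustrative only: enumerate the environment/policy registry extended above.
import lerobot

print(lerobot.available_envs)  # ['aloha', 'pusht', 'xarm', 'dora_aloha_real']
print(lerobot.available_policies_per_env["dora_aloha_real"])  # ['act_real']
print(("dora_aloha_real", "DoraAloha-v0") in lerobot.env_task_pairs)  # True
```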
@@ -52,6 +55,7 @@ available_tasks_per_env = {
     ],
     "pusht": ["PushT-v0"],
     "xarm": ["XarmLift-v0"],
+    "dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
 }
 
 available_envs = list(available_tasks_per_env.keys())
@@ -77,6 +81,23 @@ available_datasets_per_env = {
         "lerobot/xarm_push_medium_image",
         "lerobot/xarm_push_medium_replay_image",
     ],
+    "dora_aloha_real": [
+        "lerobot/aloha_static_battery",
+        "lerobot/aloha_static_candy",
+        "lerobot/aloha_static_coffee",
+        "lerobot/aloha_static_coffee_new",
+        "lerobot/aloha_static_cups_open",
+        "lerobot/aloha_static_fork_pick_up",
+        "lerobot/aloha_static_pingpong_test",
+        "lerobot/aloha_static_pro_pencil",
+        "lerobot/aloha_static_screw_driver",
+        "lerobot/aloha_static_tape",
+        "lerobot/aloha_static_thread_velcro",
+        "lerobot/aloha_static_towel",
+        "lerobot/aloha_static_vinh_cup",
+        "lerobot/aloha_static_vinh_cup_left",
+        "lerobot/aloha_static_ziploc_slide",
+    ],
 }
 
 available_real_world_datasets = [
@@ -108,16 +129,19 @@ available_datasets = list(
     itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
 )
 
+# Lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
 available_policies = [
     "act",
     "diffusion",
     "tdmpc",
 ]
 
+# Keys and values refer to yaml files.
 available_policies_per_env = {
     "aloha": ["act"],
     "pusht": ["diffusion"],
     "xarm": ["tdmpc"],
+    "dora_aloha_real": ["act_real"],
 }
 
 env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index b48a9211..4732f577 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -55,11 +55,19 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
             "strings to load multiple datasets."
         )
 
-    if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
-        logging.warning(
-            f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
-            f"environment ({cfg.env.name=})."
-        )
+    # A soft check that warns if the environment does not match the dataset. Don't check if we are using a real-world env (dora).
+    if cfg.env.name != "dora":
+        if isinstance(cfg.dataset_repo_id, str):
+            dataset_repo_ids = [cfg.dataset_repo_id]  # single dataset
+        else:
+            dataset_repo_ids = cfg.dataset_repo_id  # multiple datasets
+
+        for dataset_repo_id in dataset_repo_ids:
+            if cfg.env.name not in dataset_repo_id:
+                logging.warning(
+                    f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
+                    f"environment ({cfg.env.name=})."
+                )
 
     resolve_delta_timestamps(cfg)
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 95374f4d..a4b0b7d2 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -25,6 +25,13 @@ class ACTConfig:
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
     Those are: `input_shapes` and `output_shapes`.
 
+    Notes on the inputs and outputs:
+        - At least one key starting with "observation.image" is required as an input.
+        - If there are multiple keys beginning with "observation.images.", they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
+        - May optionally work without an "observation.state" key for the proprioceptive robot state.
+        - "action" is required as an output key.
+
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
@@ -33,15 +40,15 @@ class ACTConfig:
             This should be no greater than the chunk size. For example, if the chunk size is 100, you may
             set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in
             the environment, and throws the other 50 out.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.images.top" refers to an input from the
-            "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesn't include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
             which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
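To make the `input_shapes` / `output_shapes` convention documented above concrete, here is a hypothetical instance for the real-world Aloha setup (the key names and dimensions follow `act_real.yaml`, introduced later in this patch):

```python
# Shapes exclude batch and temporal dimensions, per the docstring above.
input_shapes = {
    "observation.images.cam_right_wrist": [3, 480, 640],
    "observation.images.cam_left_wrist": [3, 480, 640],
    "observation.images.cam_high": [3, 480, 640],
    "observation.images.cam_low": [3, 480, 640],
    "observation.state": [14],  # omitted entirely in act_real_no_state.yaml
}
output_shapes = {"action": [14]}
```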
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index eafe677b..bef59bec 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -200,25 +200,29 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+        self.use_input_state = "observation.state" in config.input_shapes
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
-            )
+            if self.use_input_state:
+                self.vae_encoder_robot_state_input_proj = nn.Linear(
+                    config.input_shapes["observation.state"][0], config.dim_model
+                )
             # Projection layer for action (joint-space target) to hidden dimension.
             self.vae_encoder_action_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                config.output_shapes["action"][0], config.dim_model
             )
-            self.latent_dim = config.latent_dim
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
-            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
+            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
+            num_input_token_encoder = 1 + config.chunk_size
+            if self.use_input_state:
+                num_input_token_encoder += 1
             self.register_buffer(
                 "vae_encoder_pos_enc",
-                create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
             )
 
         # Backbone for image feature extraction.
@@ -238,15 +242,17 @@ class ACT(nn.Module):
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(
-            config.input_shapes["observation.state"][0], config.dim_model
-        )
-        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
+        if self.use_input_state:
+            self.encoder_robot_state_input_proj = nn.Linear(
+                config.input_shapes["observation.state"][0], config.dim_model
+            )
+        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+        num_input_token_decoder = 2 if self.use_input_state else 1
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
 
         # Transformer decoder.
@@ -285,7 +291,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        batch_size = batch["observation.state"].shape[0]
+        batch_size = batch["observation.images"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
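A reading aid between hunks (not part of the patch): the VAE encoder consumes `[cls, optional robot_state, *action_chunk]`, so the `num_input_token_encoder` computed above works out as follows. The helper name is invented for illustration.

```python
def vae_encoder_token_count(chunk_size: int, use_input_state: bool) -> int:
    # [cls] + optional [robot_state] + one token per action in the chunk.
    return 1 + int(use_input_state) + chunk_size

assert vae_encoder_token_count(100, use_input_state=True) == 102   # act.yaml / act_real.yaml
assert vae_encoder_token_count(100, use_input_state=False) == 101  # act_real_no_state.yaml
```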
@@ -293,11 +299,16 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-                1
-            )  # (B, 1, D)
+            if self.use_input_state:
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+
+            if self.use_input_state:
+                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+            else:
+                vae_encoder_input = [cls_embed, action_embed]
+            vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
 
             # Prepare fixed positional embedding.
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
@@ -308,16 +319,17 @@ class ACT(nn.Module):
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
-            mu = latent_pdf_params[:, : self.latent_dim]
+            mu = latent_pdf_params[:, : self.config.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
-            log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+            log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
 
             # Sample the latent with the reparameterization trick.
             latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
-            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speed up forward; precompute and use buffer.
+            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
 
@@ -326,8 +338,10 @@ class ACT(nn.Module):
         all_cam_features = []
         all_cam_pos_embeds = []
         images = batch["observation.images"]
+
         for cam_index in range(images.shape[-4]):
             cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speed up forward; precompute and use buffer.
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
@@ -337,13 +351,15 @@ class ACT(nn.Module):
         cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
 
         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        if self.use_input_state:
+            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
 
         # Stack encoder input and positional embeddings moving to (S, B, C).
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
         encoder_in = torch.cat(
             [
-                torch.stack([latent_embed, robot_state_embed], axis=0),
+                torch.stack(encoder_in_feats, axis=0),
                 einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
@@ -357,6 +373,7 @@ class ACT(nn.Module):
 
         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        # TODO(rcadene, alexander-soare): remove call to `device`; precompute and use buffer.
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
             dtype=pos_embed.dtype,
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 81ff5de7..59ed1656 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -26,21 +26,26 @@ class DiffusionConfig:
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
     Those are: `input_shapes` and `output_shapes`.
 
+    Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
+        - A key starting with "observation.image" is required as an input.
+        - "action" is required as an output key.
+
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
         horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
         n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
             See `DiffusionPolicy.select_action` for more details.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.image" refers to an input from
-            a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesnt include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply.
            The two available modes are "mean_std" which subtracts the mean and divides by the standard
            deviation and "min_max" which rescale in a
diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
index cf76fb08..49485c39 100644
--- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -31,6 +31,15 @@ class TDMPCConfig:
         n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
             action repeats in Q-learning or ask your favorite chatbot)
         horizon: Horizon for model predictive control.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
             which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
diff --git a/lerobot/configs/env/dora_aloha_real.yaml b/lerobot/configs/env/dora_aloha_real.yaml
new file mode 100644
index 00000000..088781d4
--- /dev/null
+++ b/lerobot/configs/env/dora_aloha_real.yaml
@@ -0,0 +1,13 @@
+# @package _global_
+
+fps: 30
+
+env:
+  name: dora
+  task: DoraAloha-v0
+  state_dim: 14
+  action_dim: 14
+  fps: ${fps}
+  episode_length: 400
+  gym:
+    fps: ${fps}
diff --git a/lerobot/configs/policy/act_real.yaml b/lerobot/configs/policy/act_real.yaml
new file mode 100644
index 00000000..b4942615
--- /dev/null
+++ b/lerobot/configs/policy/act_real.yaml
@@ -0,0 +1,115 @@
+# @package _global_
+
+# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
+# Compared to `act.yaml`, it uses 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high,
+# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1: this parameter
+# controls how often checkpoints are evaluated during training, and -1 deactivates evaluation.
+# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
+# See its README for more information on how to evaluate a checkpoint in the real world.
+#
+# Example of usage for training:
+# ```bash
+# python lerobot/scripts/train.py \
+#   policy=act_real \
+#   env=dora_aloha_real
+# ```
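To make the `override_dataset_stats` block below concrete: a minimal sketch of what the "mean_std" normalization mode computes for one camera frame with these ImageNet statistics (illustrative only; not the lerobot normalization code itself).

```python
import torch

# ImageNet statistics shaped (c, 1, 1) so they broadcast over (c, h, w) frames.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

frame = torch.rand(3, 480, 640)    # dummy cam_high image in [0, 1]
normalized = (frame - mean) / std  # the "mean_std" mode described in the configs
```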
+
+seed: 1000
+dataset_repo_id: lerobot/aloha_static_vinh_cup
+
+override_dataset_stats:
+  observation.images.cam_right_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_left_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_high:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_low:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1
+  save_freq: 10000
+  log_freq: 100
+  save_checkpoint: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+# See `configuration_act.py` for more details.
+policy:
+  name: act
+
+  # Input / output structure.
+  n_obs_steps: 1
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images.cam_right_wrist: [3, 480, 640]
+    observation.images.cam_left_wrist: [3, 480, 640]
+    observation.images.cam_high: [3, 480, 640]
+    observation.images.cam_low: [3, 480, 640]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.cam_right_wrist: mean_std
+    observation.images.cam_left_wrist: mean_std
+    observation.images.cam_high: mean_std
+    observation.images.cam_low: mean_std
+    observation.state: mean_std
+  output_normalization_modes:
+    action: mean_std
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  dim_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  temporal_ensemble_momentum: null
+
+  # Training and loss computation.
+  dropout: 0.1
+  kl_weight: 10.0
diff --git a/lerobot/configs/policy/act_real_no_state.yaml b/lerobot/configs/policy/act_real_no_state.yaml
new file mode 100644
index 00000000..a8b1c9b6
--- /dev/null
+++ b/lerobot/configs/policy/act_real_no_state.yaml
@@ -0,0 +1,111 @@
+# @package _global_
+
+# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving
+# (e.g. wrist cameras). Compared to `act_real.yaml`, it is camera-only: it does not use the state
+# (a vector of robot joint positions) as input. We validated experimentally that not using the state
+# achieves a better success rate. Our hypothesis is that `act_real.yaml` might overfit to the state,
+# since the moving cameras make the images harder to learn from.
+#
+# Example of usage for training:
+# ```bash
+# python lerobot/scripts/train.py \
+#   policy=act_real_no_state \
+#   env=dora_aloha_real
+# ```
+
+seed: 1000
+dataset_repo_id: lerobot/aloha_static_vinh_cup
+
+override_dataset_stats:
+  observation.images.cam_right_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_left_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_high:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_low:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1
+  save_freq: 10000
+  log_freq: 100
+  save_checkpoint: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+# See `configuration_act.py` for more details.
+policy:
+  name: act
+
+  # Input / output structure.
+  n_obs_steps: 1
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images.cam_right_wrist: [3, 480, 640]
+    observation.images.cam_left_wrist: [3, 480, 640]
+    observation.images.cam_high: [3, 480, 640]
+    observation.images.cam_low: [3, 480, 640]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.cam_right_wrist: mean_std
+    observation.images.cam_left_wrist: mean_std
+    observation.images.cam_high: mean_std
+    observation.images.cam_low: mean_std
+  output_normalization_modes:
+    action: mean_std
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  dim_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  temporal_ensemble_momentum: null
+
+  # Training and loss computation.
+ dropout: 0.1 + kl_weight: 10.0 diff --git a/poetry.lock b/poetry.lock index 3a04e3d1..e9d6e848 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -444,63 +444,63 @@ files = [ [[package]] name = "coverage" -version = "7.5.1" +version = "7.5.3" description = "Code coverage measurement for Python" optional = true python-versions = ">=3.8" files = [ - {file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"}, - {file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"}, - {file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"}, - {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"}, - {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"}, - {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"}, - {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"}, - {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"}, - {file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"}, - {file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"}, - {file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"}, - {file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"}, - {file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"}, - {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"}, - {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"}, - {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"}, - {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"}, - {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"}, - {file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = 
"sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"}, - {file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"}, - {file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"}, - {file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"}, - {file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"}, - {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"}, - {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"}, - {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"}, - {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"}, - {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"}, - {file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"}, - {file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"}, - {file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"}, - {file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"}, - {file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"}, - {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"}, - {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"}, - {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"}, - {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"}, - {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"}, - {file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"}, - {file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"}, - {file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"}, - {file = 
"coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"}, - {file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"}, - {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"}, - {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"}, - {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"}, - {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"}, - {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"}, - {file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"}, - {file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"}, - {file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"}, - {file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"}, + {file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"}, + {file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"}, + {file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"}, + {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"}, + {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"}, + {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"}, + {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"}, + {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"}, + {file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"}, + {file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"}, + {file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"}, + {file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"}, + {file 
= "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"}, + {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"}, + {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"}, + {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"}, + {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"}, + {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"}, + {file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"}, + {file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"}, + {file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"}, + {file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"}, + {file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"}, + {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"}, + {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"}, + {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"}, + {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"}, + {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"}, + {file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"}, + {file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"}, + {file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"}, + {file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"}, + {file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"}, + {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"}, + {file = 
"coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"}, + {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"}, + {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"}, + {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"}, + {file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"}, + {file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"}, + {file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"}, + {file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"}, + {file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"}, + {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"}, + {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"}, + {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"}, + {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"}, + {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"}, + {file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"}, + {file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"}, + {file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"}, + {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"}, ] [package.dependencies] @@ -785,6 +785,26 @@ files = [ [package.dependencies] six = ">=1.4.0" +[[package]] +name = "dora-rs" +version = "0.3.4" +description = "`dora` goal is to be a low latency, composable, and distributed data flow." 
+optional = true +python-versions = "*" +files = [ + {file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"}, + {file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"}, + {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"}, + {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"}, + {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"}, + {file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"}, + {file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"}, + {file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"}, +] + +[package.dependencies] +pyarrow = "*" + [[package]] name = "einops" version = "0.8.0" @@ -1066,6 +1086,27 @@ mujoco = ">=2.3.7,<3.0.0" dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"] test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"] +[[package]] +name = "gym-dora" +version = "0.1.0" +description = "" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +dora-rs = ">=0.3.4" +gymnasium = ">=0.29.1" +pyarrow = ">=12.0.0" + +[package.source] +type = "git" +url = "https://github.com/dora-rs/dora-lerobot.git" +reference = "HEAD" +resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79" +subdirectory = "gym_dora" + [[package]] name = "gym-pusht" version = "0.1.4" @@ -1269,13 +1310,13 @@ files = [ [[package]] name = "huggingface-hub" -version = "0.23.1" +version = "0.23.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"}, - {file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"}, + {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"}, + {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"}, ] [package.dependencies] @@ -2061,18 +2102,15 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "nodeenv" -version = "1.8.0" +version = "1.9.0" description = "Node.js virtual environment builder" optional = true -python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ - {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, - {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, + {file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"}, + {file = 
"nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"}, ] -[package.dependencies] -setuptools = "*" - [[package]] name = "numba" version = "0.59.1" @@ -2406,6 +2444,7 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -2426,6 +2465,7 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -3188,13 +3228,13 @@ files = [ [[package]] name = "requests" -version = "2.32.2" +version = "2.32.3" description = "Python HTTP for Humans." 
optional = false python-versions = ">=3.8" files = [ - {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, - {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [package.dependencies] @@ -3210,16 +3250,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rerun-sdk" -version = "0.16.0" +version = "0.16.1" description = "The Rerun Logging SDK" optional = false python-versions = "<3.13,>=3.8" files = [ - {file = "rerun_sdk-0.16.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1cc6dc66d089e296f945dc238301889efb61dd6d338b5d00f76981cf7aed0a74"}, - {file = "rerun_sdk-0.16.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:faf231897655e46eb975695df2b0ace07db362d697e697f9a3dff52f81c0dc5d"}, - {file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:860a6394380d3e9b9e48bf34423bd56dda54d5b0158d2ae0e433698659b34198"}, - {file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:5b8d1476f73a3ad1a5d3f21b61c633f3ab62aa80fa0b049f5ad10bf1227681ab"}, - {file = "rerun_sdk-0.16.0-cp38-abi3-win_amd64.whl", hash = "sha256:aff0051a263b8c3067243c0126d319845baf4fe640899f04aeef7daf151f35e4"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"}, ] [package.dependencies] @@ -3696,17 +3736,17 @@ files = [ [[package]] name = "sympy" -version = "1.12" +version = "1.12.1" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" files = [ - {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, - {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, + {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, + {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, ] [package.dependencies] -mpmath = ">=0.19" +mpmath = ">=1.1.0,<1.4.0" [[package]] name = "tbb" @@ -4220,13 +4260,13 @@ multidict = ">=4.0" [[package]] name = "zarr" -version = "2.18.1" +version = "2.18.2" description = "An implementation of chunked, compressed, N-dimensional arrays for Python" optional = false python-versions = ">=3.9" files = [ - {file = "zarr-2.18.1-py3-none-any.whl", hash = "sha256:a1770d194eec4ec0a41a01295a6f724e1c3471d704d3aca906d3b3a7f8830245"}, - {file = "zarr-2.18.1.tar.gz", hash = 
"sha256:28c360ed123e606c425a694a83300227a907cb86a995fc9eef620ecafbe5f92d"}, + {file = "zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38"}, + {file = "zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47"}, ] [package.dependencies] @@ -4241,13 +4281,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"] [[package]] name = "zipp" -version = "3.18.2" +version = "3.19.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"}, - {file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"}, + {file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"}, + {file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"}, ] [package.extras] @@ -4257,6 +4297,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more [extras] aloha = ["gym-aloha"] dev = ["debugpy", "pre-commit"] +dora = ["gym-dora"] pusht = ["gym-pusht"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] @@ -4265,4 +4306,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "1ad6ef0f88f0056ab639e60e033e586f7460a9c5fc3676a477bbd47923f41cb6" +content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771" diff --git a/pyproject.toml b/pyproject.toml index 1dd8b1d6..0c305218 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ h5py = ">=3.10.0" huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"} gymnasium = ">=0.29.1" cmake = ">=3.29.0.1" +gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true } gym-pusht = { version = ">=0.1.3", optional = true} gym-xarm = { version = ">=0.1.1", optional = true} gym-aloha = { version = ">=0.1.1", optional = true} @@ -62,6 +63,7 @@ deepdiff = ">=7.0.1" [tool.poetry.extras] +dora = ["gym-dora"] pusht = ["gym-pusht"] xarm = ["gym-xarm"] aloha = ["gym-aloha"] diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors new file mode 100644 index 00000000..2373f1ee --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4 +size 5104 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors new file mode 100644 index 00000000..de40a20e --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901 +size 31688 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors new file mode 100644 index 00000000..8602cc56 --- /dev/null 
+    obs = {}
+    for k in batch:
+        if k.startswith("observation"):
+            obs[k] = batch[k]
+
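+    # Some policy configs define `n_action_steps` (e.g. ACT, diffusion) while others
+    # define `n_action_repeats` (e.g. TD-MPC); use whichever one is present.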
+    if "n_action_steps" in cfg.policy:
+        actions_queue = cfg.policy.n_action_steps
+    else:
+        actions_queue = cfg.policy.n_action_repeats
 
-    actions_queue = (
-        cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
-    )
     actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
     return output_dict, grad_stats, param_stats, actions
@@ -114,6 +115,8 @@ if __name__ == "__main__":
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
         ),
         ("aloha", "act", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
     ]
     for env, policy, extra_overrides in env_policies:
         save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
diff --git a/tests/test_policies.py b/tests/test_policies.py
index bb0c7b80..c099bef0 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
-from tests.scripts.save_policy_to_safetensor import get_policy_stats
+from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
         ),
         # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
         ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
+        ("dora_aloha_real", "act_real", []),
+        ("dora_aloha_real", "act_real_no_state", []),
     ],
 )
 @require_env
@@ -84,6 +86,9 @@ def test_policy(env_name, policy_name, extra_overrides):
     - Updating the policy.
     - Using the policy to select actions at inference time.
     - Test the action can be applied to the policy
+
+    Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
+    and for now we add tests as we see fit.
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
@@ -135,7 +140,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     dataloader = torch.utils.data.DataLoader(
         dataset,
-        num_workers=4,
+        num_workers=0,
         batch_size=2,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
@@ -291,6 +296,8 @@ def test_normalize(insert_temporal_dim):
         ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
     ),
     ("aloha", "act", ["policy.n_action_steps=10"]),
+    ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+    ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
 ],
 )
 # As artifacts have been generated on an x86_64 kernel, this test won't