Compare commits: main...fix_aloha_ (9 commits)

Commits:
- 9b5d2fd37d
- 97ea288084
- 671ad93b6c
- b7b5c3b4ff
- 1397036a6b
- c1570e40c6
- 8d847a58ef
- 48f974bb9e
- 511e39bdb8
.gitignore (vendored, 1 change)
@@ -121,7 +121,6 @@ celerybeat.pid
 # Environments
 .env
 .venv
-env/
 venv/
 ENV/
 env.bak/
@@ -52,6 +52,7 @@ available_tasks_per_env = {
     ],
     "pusht": ["PushT-v0"],
     "xarm": ["XarmLift-v0"],
+    "dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
 }

 available_envs = list(available_tasks_per_env.keys())

@@ -77,6 +78,23 @@ available_datasets_per_env = {
         "lerobot/xarm_push_medium_image",
         "lerobot/xarm_push_medium_replay_image",
     ],
+    "dora": [
+        "lerobot/aloha_static_battery",
+        "lerobot/aloha_static_candy",
+        "lerobot/aloha_static_coffee",
+        "lerobot/aloha_static_coffee_new",
+        "lerobot/aloha_static_cups_open",
+        "lerobot/aloha_static_fork_pick_up",
+        "lerobot/aloha_static_pingpong_test",
+        "lerobot/aloha_static_pro_pencil",
+        "lerobot/aloha_static_screw_driver",
+        "lerobot/aloha_static_tape",
+        "lerobot/aloha_static_thread_velcro",
+        "lerobot/aloha_static_towel",
+        "lerobot/aloha_static_vinh_cup",
+        "lerobot/aloha_static_vinh_cup_left",
+        "lerobot/aloha_static_ziploc_slide",
+    ],
 }

 available_real_world_datasets = [

@@ -116,6 +134,7 @@ available_policies = [

 available_policies_per_env = {
     "aloha": ["act"],
+    "dora": ["act"],
     "pusht": ["diffusion"],
     "xarm": ["tdmpc"],
 }
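These registry additions wire the new `dora` environment, its real-world Aloha datasets, and its default ACT policy into the package-level lookup tables. A minimal sanity check, assuming these registries are re-exported by the top-level `lerobot` package (the hunk context suggests this, but the file path is not preserved in this export):

```python
# Minimal sketch: confirm the new "dora" entries are reachable through the
# package-level registries. The top-level `lerobot` import is an assumption.
import lerobot

assert "dora" in lerobot.available_envs
assert "DoraAloha-v0" in lerobot.available_tasks_per_env["dora"]
assert "lerobot/aloha_static_towel" in lerobot.available_datasets_per_env["dora"]
assert lerobot.available_policies_per_env["dora"] == ["act"]
```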
@@ -78,15 +78,29 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
 
     image_keys = [key for key in df if "observation.images." in key]
 
+    num_unaligned_images = 0
+    max_episode = 0
+
     def get_episode_index(row):
+        nonlocal num_unaligned_images
+        nonlocal max_episode
         episode_index_per_cam = {}
         for key in image_keys:
+            if isinstance(row[key], float):
+                num_unaligned_images += 1
+                return float("nan")
             path = row[key][0]["path"]
             match = re.search(r"_(\d{6}).mp4", path)
             if not match:
                 raise ValueError(path)
             episode_index = int(match.group(1))
             episode_index_per_cam[key] = episode_index
+
+            if episode_index > max_episode:
+                assert episode_index - max_episode == 1
+                max_episode = episode_index
+            else:
+                assert episode_index == max_episode
         if len(set(episode_index_per_cam.values())) != 1:
             raise ValueError(
                 f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
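The episode index is recovered from each camera's video path, which is expected to end in a six-digit episode number followed by `.mp4`. A standalone sketch of that extraction (the path below is made up for illustration):

```python
import re

# Standalone sketch of the episode-index extraction used above.
# The path is hypothetical; real paths come from the raw dora recordings.
path = "videos/observation.images.cam_high_episode_000042.mp4"
match = re.search(r"_(\d{6}).mp4", path)
if not match:
    raise ValueError(path)
episode_index = int(match.group(1))
assert episode_index == 42
```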
@@ -111,11 +125,24 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
     del df["timestamp_utc"]
 
     # sanity check
-    has_nan = df.isna().any().any()
-    if has_nan:
-        raise ValueError("Dataset contains Nan values.")
+    num_rows_with_nan = df.isna().any(axis=1).sum()
+    assert (
+        num_rows_with_nan == num_unaligned_images
+    ), f"Found {num_rows_with_nan} rows with NaN values but {num_unaligned_images} unaligned images."
+    if num_unaligned_images > max_episode * 2:
+        # We allow a few unaligned images, typically at the beginning and end of the episodes for instance
+        # but if there are too many, we raise an error to avoid large chunks of missing data
+        raise ValueError(
+            f"Found {num_unaligned_images} unaligned images out of {max_episode} episodes. "
+            f"Check the timestamps of the cameras."
+        )
+
+    # Drop rows with NaN values now that we double checked and convert episode_index to int
+    df = df.dropna()
+    df["episode_index"] = df["episode_index"].astype(int)
 
     # sanity check episode indices go from 0 to n-1
+    assert df["episode_index"].max() == max_episode
     ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
     expected_ep_ids = list(range(df["episode_index"].max() + 1))
     if ep_ids != expected_ep_ids:
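The new sanity check ties the number of NaN rows back to the number of unaligned camera frames counted during episode-index extraction, tolerates a small number of them, and only then drops the incomplete rows. A toy illustration on a tiny DataFrame (values invented for the example):

```python
import pandas as pd

# Toy illustration of the tolerance check above: one unaligned frame produced
# one NaN row, which is within the allowed budget of max_episode * 2.
df = pd.DataFrame({"episode_index": [0.0, 1.0, float("nan"), 1.0]})
num_unaligned_images = 1
max_episode = 1

num_rows_with_nan = df.isna().any(axis=1).sum()
assert num_rows_with_nan == num_unaligned_images
if num_unaligned_images > max_episode * 2:
    raise ValueError("Too many unaligned images, check the timestamps of the cameras.")

df = df.dropna()
df["episode_index"] = df["episode_index"].astype(int)
assert list(df["episode_index"]) == [0, 1, 1]
```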
@@ -214,8 +241,6 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
 
     if fps is None:
         fps = 30
-    else:
-        raise NotImplementedError()
 
     if not video:
         raise NotImplementedError()
@@ -25,6 +25,14 @@ class ACTConfig:
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
     Those are: `input_shapes` and 'output_shapes`.
 
+    Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
+        - At least one key starting with "observation.image is required as an input.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views.
+          Right now we only support all images having the same shape.
+        - "action" is required as an output key.
+
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).

@@ -33,15 +41,15 @@ class ACTConfig:
             This should be no greater than the chunk size. For example, if the chunk size size 100, you may
             set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
             environment, and throws the other 50 out.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.images.top" refers to an input from the
-            "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesn't include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
             which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
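Concretely, the real-robot configs added later in this compare (see `act_real.yaml` below) fill these dictionaries with four 480x640 RGB cameras, a 14-dimensional joint state, and 14-dimensional actions:

```python
# Example shape dictionaries mirroring `act_real.yaml` further down in this compare.
# As the docstring notes, no batch or temporal dimensions are included.
input_shapes = {
    "observation.images.cam_right_wrist": [3, 480, 640],
    "observation.images.cam_left_wrist": [3, 480, 640],
    "observation.images.cam_high": [3, 480, 640],
    "observation.images.cam_low": [3, 480, 640],
    "observation.state": [14],  # resolved from ${env.state_dim} in the Hydra config
}
output_shapes = {"action": [14]}  # resolved from ${env.action_dim}
```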
@@ -200,25 +200,28 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+        self.has_state = "observation.state" in config.input_shapes
+        self.latent_dim = config.latent_dim
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
-            )
+            if self.has_state:
+                self.vae_encoder_robot_state_input_proj = nn.Linear(
+                    config.input_shapes["observation.state"][0], config.dim_model
+                )
             # Projection layer for action (joint-space target) to hidden dimension.
             self.vae_encoder_action_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                config.output_shapes["action"][0], config.dim_model
             )
-            self.latent_dim = config.latent_dim
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
             self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
             # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
             # dimension.
+            num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
             self.register_buffer(
                 "vae_encoder_pos_enc",
-                create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
             )
 
         # Backbone for image feature extraction.

@@ -238,15 +241,17 @@ class ACT(nn.Module):
 
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(
-            config.input_shapes["observation.state"][0], config.dim_model
-        )
+        if self.has_state:
+            self.encoder_robot_state_input_proj = nn.Linear(
+                config.input_shapes["observation.state"][0], config.dim_model
+            )
         self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+        num_input_token_decoder = 2 if self.has_state else 1
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
 
         # Transformer decoder.
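The net effect of the new `has_state` flag is to make `"observation.state"` optional: the state projection layers are only created when the key is present, and the two token counts shrink by one without it. The arithmetic, as a small sketch:

```python
# Sketch of the token-count arithmetic introduced above (plain Python, no model code).
chunk_size = 100
for has_state in (True, False):
    # VAE encoder input: [cls] (+ [robot_state] when available) + the action chunk.
    num_input_token_encoder = 1 + 1 + chunk_size if has_state else 1 + chunk_size
    # Positional embeddings for the transformer encoder's latent (+ robot_state) tokens.
    num_input_token_decoder = 2 if has_state else 1
    print(f"has_state={has_state}: {num_input_token_encoder=} {num_input_token_decoder=}")
```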
@@ -285,7 +290,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        batch_size = batch["observation.state"].shape[0]
+        batch_size = batch["observation.images"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:

@@ -293,11 +298,16 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-                1
-            )  # (B, 1, D)
+            if self.has_state:
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+
+            if self.has_state:
+                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+            else:
+                vae_encoder_input = [cls_embed, action_embed]
+            vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
 
             # Prepare fixed positional embedding.
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
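The VAE-encoder input is now assembled as a list and concatenated once, so the token sequence is `[cls, robot_state, action_chunk]` with state and `[cls, action_chunk]` without. A quick shape check (dimensions chosen only for illustration):

```python
import torch

# Shape sketch of the two possible VAE-encoder inputs (B=2, S=100, D=512 are arbitrary).
B, S, D = 2, 100, 512
cls_embed = torch.zeros(B, 1, D)
robot_state_embed = torch.zeros(B, 1, D)
action_embed = torch.zeros(B, S, D)

with_state = torch.cat([cls_embed, robot_state_embed, action_embed], dim=1)
without_state = torch.cat([cls_embed, action_embed], dim=1)
assert with_state.shape == (B, S + 2, D)
assert without_state.shape == (B, S + 1, D)
```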
@@ -317,6 +327,7 @@ class ACT(nn.Module):
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )

@@ -326,8 +337,10 @@ class ACT(nn.Module):
         all_cam_features = []
         all_cam_pos_embeds = []
         images = batch["observation.images"]
+
         for cam_index in range(images.shape[-4]):
             cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)

@@ -337,13 +350,15 @@ class ACT(nn.Module):
         cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
 
         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        if self.has_state:
+            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
 
         # Stack encoder input and positional embeddings moving to (S, B, C).
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
         encoder_in = torch.cat(
             [
-                torch.stack([latent_embed, robot_state_embed], axis=0),
+                torch.stack(encoder_in_feats, axis=0),
                 einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )

@@ -357,6 +372,7 @@ class ACT(nn.Module):
 
         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
             dtype=pos_embed.dtype,
@@ -26,21 +26,29 @@ class DiffusionConfig:
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
     Those are: `input_shapes` and `output_shapes`.
 
+    Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
+        - At least one key starting with "observation.image is required as an input.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views.
+          Right now we only support all images having the same shape.
+        - "action" is required as an output key.
+
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
         horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
         n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
             See `DiffusionPolicy.select_action` for more details.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.image" refers to an input from
-            a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesnt include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
             which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -31,6 +31,15 @@ class TDMPCConfig:
         n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
             action repeats in Q-learning or ask your favorite chatbot)
         horizon: Horizon for model predictive control.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
             which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
lerobot/configs/env/dora_aloha_real.yaml (new file, 13 lines)
@@ -0,0 +1,13 @@
+# @package _global_
+
+fps: 30
+
+env:
+  name: dora
+  task: DoraAloha-v0
+  state_dim: 14
+  action_dim: 14
+  fps: ${fps}
+  episode_length: 400
+  gym:
+    fps: ${fps}
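The `${fps}` entries are Hydra/OmegaConf interpolations pointing back at the top-level `fps: 30`, so the environment and its gym wrapper inherit the same frame rate. A small resolution sketch using OmegaConf directly, outside of Hydra:

```python
from omegaconf import OmegaConf

# Sketch of how the interpolations in this env config resolve to the root value.
cfg = OmegaConf.create({"fps": 30, "env": {"fps": "${fps}", "gym": {"fps": "${fps}"}}})
assert cfg.env.fps == 30
assert cfg.env.gym.fps == 30
```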
lerobot/configs/policy/act_real.yaml (new file, 115 lines)
@@ -0,0 +1,115 @@
+# @package _global_
+
+# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
+# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
+# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
+# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
+# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
+# Look at its README for more information on how to evaluate a checkpoint in the real-world.
+#
+# Example of usage for training:
+# ```bash
+# python lerobot/scripts/train.py \
+#   policy=act_real \
+#   env=aloha_real
+# ```
+
+seed: 1000
+dataset_repo_id: lerobot/aloha_static_vinh_cup
+
+override_dataset_stats:
+  observation.images.cam_right_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_left_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_high:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_low:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1
+  save_freq: 10000
+  log_freq: 100
+  save_checkpoint: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+# See `configuration_act.py` for more details.
+policy:
+  name: act
+
+  # Input / output structure.
+  n_obs_steps: 1
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images.cam_right_wrist: [3, 480, 640]
+    observation.images.cam_left_wrist: [3, 480, 640]
+    observation.images.cam_high: [3, 480, 640]
+    observation.images.cam_low: [3, 480, 640]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.cam_right_wrist: mean_std
+    observation.images.cam_left_wrist: mean_std
+    observation.images.cam_high: mean_std
+    observation.images.cam_low: mean_std
+    observation.state: mean_std
+  output_normalization_modes:
+    action: mean_std
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  dim_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  temporal_ensemble_momentum: null
+
+  # Training and loss computation.
+  dropout: 0.1
+  kl_weight: 10.0
lerobot/configs/policy/act_real_no_state.yaml (new file, 111 lines)
@@ -0,0 +1,111 @@
+# @package _global_
+
+# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras)
+# Compared to `act_real.yaml`, it is camera only and does not use the state as input which is vector of robot joint positions.
+# We validated experimentaly that not using state reaches better success rate. Our hypothesis is that `act_real.yaml` might
+# overfits to the state, because the images are more complex to learn from since they are moving.
+#
+# Example of usage for training:
+# ```bash
+# python lerobot/scripts/train.py \
+#   policy=act_real_no_state \
+#   env=aloha_real
+# ```
+
+seed: 1000
+dataset_repo_id: lerobot/aloha_static_vinh_cup
+
+override_dataset_stats:
+  observation.images.cam_right_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_left_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_high:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_low:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1
+  save_freq: 10000
+  log_freq: 100
+  save_checkpoint: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+# See `configuration_act.py` for more details.
+policy:
+  name: act
+
+  # Input / output structure.
+  n_obs_steps: 1
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images.cam_right_wrist: [3, 480, 640]
+    observation.images.cam_left_wrist: [3, 480, 640]
+    observation.images.cam_high: [3, 480, 640]
+    observation.images.cam_low: [3, 480, 640]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.cam_right_wrist: mean_std
+    observation.images.cam_left_wrist: mean_std
+    observation.images.cam_high: mean_std
+    observation.images.cam_low: mean_std
+  output_normalization_modes:
+    action: mean_std
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  dim_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  temporal_ensemble_momentum: null
+
+  # Training and loss computation.
+  dropout: 0.1
+  kl_weight: 10.0
poetry.lock (generated, 48 changes)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"

@@ -785,6 +785,26 @@ files = [
 [package.dependencies]
 six = ">=1.4.0"
 
+[[package]]
+name = "dora-rs"
+version = "0.3.4"
+description = "`dora` goal is to be a low latency, composable, and distributed data flow."
+optional = true
+python-versions = "*"
+files = [
+    {file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"},
+    {file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"},
+    {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"},
+    {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"},
+    {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"},
+    {file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"},
+    {file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"},
+    {file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"},
+]
+
+[package.dependencies]
+pyarrow = "*"
+
 [[package]]
 name = "einops"
 version = "0.8.0"

@@ -1066,6 +1086,27 @@ mujoco = ">=2.3.7,<3.0.0"
 dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
 test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
 
+[[package]]
+name = "gym-dora"
+version = "0.1.0"
+description = ""
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dora-rs = ">=0.3.4"
+gymnasium = ">=0.29.1"
+pyarrow = ">=12.0.0"
+
+[package.source]
+type = "git"
+url = "https://github.com/dora-rs/dora-lerobot.git"
+reference = "HEAD"
+resolved_reference = "1c6c2a401c3a2967d41444be6286ca9a28893abf"
+subdirectory = "gym_dora"
+
 [[package]]
 name = "gym-pusht"
 version = "0.1.4"

@@ -2406,6 +2447,7 @@ optional = false
 python-versions = ">=3.9"
 files = [
     {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
+    {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
     {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
     {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
     {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},

@@ -2426,6 +2468,7 @@ files = [
     {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
     {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
     {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
+    {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
     {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
     {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
     {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},

@@ -4257,6 +4300,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more
 [extras]
 aloha = ["gym-aloha"]
 dev = ["debugpy", "pre-commit"]
+dora = ["gym-dora"]
 pusht = ["gym-pusht"]
 test = ["pytest", "pytest-cov"]
 umi = ["imagecodecs"]

@@ -4265,4 +4309,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "1ad6ef0f88f0056ab639e60e033e586f7460a9c5fc3676a477bbd47923f41cb6"
+content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771"
@@ -46,6 +46,7 @@ h5py = ">=3.10.0"
 huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
 gymnasium = ">=0.29.1"
 cmake = ">=3.29.0.1"
+gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
 gym-pusht = { version = ">=0.1.3", optional = true}
 gym-xarm = { version = ">=0.1.1", optional = true}
 gym-aloha = { version = ">=0.1.1", optional = true}

@@ -62,6 +63,7 @@ deepdiff = ">=7.0.1"
 
 
 [tool.poetry.extras]
+dora = ["gym-dora"]
 pusht = ["gym-pusht"]
 xarm = ["gym-xarm"]
 aloha = ["gym-aloha"]
New Git LFS pointer files (eight new test artifacts):

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebd21273f6048b66c806f92035352843a9069908b3296863fd55d34cf71cd0ef
+size 51248

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9bbf951891077320a5da27e77ddb580a6e833e8d3162b62a2f887a1989585cc
+size 31688

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
+size 68

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42f92239223bb4df32d5c3016bc67450159f1285a7ab046307b645f699ccc34e
+size 34928

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52f85d6262ad1dd0b66578b25829fed96aaaca3c7458cb73ac75111350d17fcf
+size 51248

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ba7c910618f0f3ca69f82f3d70c880d2b2e432456524a2a63dfd5c50efa45f0
+size 30808

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97455b4360748c99905cd103473c1a52da6901d0a73ffbc51b5ea3eb250d1386
+size 68

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53ad410f43855254438790f54aa7c895a052776acdd922906ae430684f659b53
+size 33608
@@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
     dataset.delta_timestamps = None
     batch = next(iter(dataloader))
-    obs = {
-        k: batch[k]
-        for k in batch
-        if k in ["observation.image", "observation.images.top", "observation.state"]
-    }
+    obs = {}
+    for k in batch:
+        if "observation" in k:
+            obs[k] = batch[k]
+
+    if "n_action_steps" in cfg.policy:
+        actions_queue = cfg.policy.n_action_steps
+    else:
+        actions_queue = cfg.policy.n_action_repeats
 
-    actions_queue = (
-        cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
-    )
     actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
     return output_dict, grad_stats, param_stats, actions
 
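The rewritten loop forwards every key containing `"observation"` to the policy instead of a hard-coded whitelist, which is what lets the new four-camera real-Aloha batches pass through unchanged. In isolation:

```python
# Isolated sketch of the broader observation filtering introduced above
# (keys and values are invented for the example).
batch = {
    "observation.images.cam_high": "img_high",
    "observation.images.cam_low": "img_low",
    "observation.state": "state",
    "action": "act",
    "index": 0,
}
obs = {}
for k in batch:
    if "observation" in k:
        obs[k] = batch[k]
assert sorted(obs) == [
    "observation.images.cam_high",
    "observation.images.cam_low",
    "observation.state",
]
```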
@@ -114,6 +115,8 @@ if __name__ == "__main__":
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
         ),
         ("aloha", "act", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real", []),
+        ("dora_aloha_real", "act_real_no_state", []),
     ]
     for env, policy, extra_overrides in env_policies:
         save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

@@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
-from tests.scripts.save_policy_to_safetensor import get_policy_stats
+from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
 
 

@@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
         ),
         # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
         ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
+        ("dora_aloha_real", "act_real", []),
+        ("dora_aloha_real", "act_real_no_state", []),
     ],
 )
 @require_env

@@ -291,6 +293,8 @@ def test_normalize(insert_temporal_dim):
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
         ),
         ("aloha", "act", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
     ],
 )
 # As artifacts have been generated on an x86_64 kernel, this test won't