Compare commits

depth...thomwolf_2

9 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | eac660bb9e |  |
|  | 1333560f6b |  |
|  | 92d1aecb40 |  |
|  | 03d237fe0f |  |
|  | 1c9f447ad0 |  |
|  | 73f1a3932d |  |
|  | 97bda08e0f |  |
|  | 51dea3f67c |  |
|  | 5495d55cc7 |  |
```diff
@@ -55,7 +55,7 @@ available_tasks_per_env = {
     ],
     "pusht": ["PushT-v0"],
     "xarm": ["XarmLift-v0"],
-    "dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
+    "dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
 }
 available_envs = list(available_tasks_per_env.keys())
 
```
```diff
@@ -81,7 +81,7 @@ available_datasets_per_env = {
         "lerobot/xarm_push_medium_image",
         "lerobot/xarm_push_medium_replay_image",
     ],
-    "dora_aloha_real": [
+    "dora": [
         "lerobot/aloha_static_battery",
         "lerobot/aloha_static_candy",
         "lerobot/aloha_static_coffee",
```
```diff
@@ -139,6 +139,7 @@ available_policies = [
 # keys and values refer to yaml files
 available_policies_per_env = {
     "aloha": ["act"],
+    "dora": ["act"],
     "pusht": ["diffusion"],
     "xarm": ["tdmpc"],
     "dora_aloha_real": ["act_real"],
```
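Taken together, these registries key tasks, datasets, and default policies off the same env name, which is why the rename from `dora_aloha_real` to `dora` has to land in all of them at once. A minimal sketch of how such a registry pair might be consumed; the dictionaries here are abbreviated stand-ins, not the full lists:

```python
# Abbreviated stand-ins for the registries above.
available_tasks_per_env = {
    "pusht": ["PushT-v0"],
    "xarm": ["XarmLift-v0"],
    "dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
}
available_policies_per_env = {
    "dora": ["act"],
    "pusht": ["diffusion"],
    "xarm": ["tdmpc"],
}

env = "dora"
assert env in available_tasks_per_env, f"unknown env: {env}"
print(available_policies_per_env[env][0])  # "act"
```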
```diff
@@ -78,15 +78,29 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
 
     image_keys = [key for key in df if "observation.images." in key]
 
+    num_unaligned_images = 0
+    max_episode = 0
+
     def get_episode_index(row):
+        nonlocal num_unaligned_images
+        nonlocal max_episode
         episode_index_per_cam = {}
         for key in image_keys:
+            if isinstance(row[key], float):
+                num_unaligned_images += 1
+                return float("nan")
             path = row[key][0]["path"]
             match = re.search(r"_(\d{6}).mp4", path)
             if not match:
                 raise ValueError(path)
             episode_index = int(match.group(1))
             episode_index_per_cam[key] = episode_index
+
+            if episode_index > max_episode:
+                assert episode_index - max_episode == 1
+                max_episode = episode_index
+            else:
+                assert episode_index == max_episode
         if len(set(episode_index_per_cam.values())) != 1:
             raise ValueError(
                 f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
```
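To make the new alignment logic easier to follow in isolation, here is a self-contained sketch of the same idea on invented rows: the paths, keys, and data are made up, the monotonicity asserts are dropped for brevity, and module-level `global` stands in for the `nonlocal` closure variables used above.

```python
import re

image_keys = ["observation.images.cam_high", "observation.images.cam_low"]
# Aligned rows map each camera to a list of video-chunk dicts; an unaligned
# row shows up as a bare float NaN instead.
rows = [{k: [{"path": f"episode_{i:06d}.mp4"}] for k in image_keys} for i in range(3)]
rows.insert(1, {k: float("nan") for k in image_keys})  # one unaligned row

num_unaligned_images = 0
max_episode = 0

def get_episode_index(row):
    global num_unaligned_images, max_episode
    episode_index_per_cam = {}
    for key in image_keys:
        if isinstance(row[key], float):  # missing frame -> flag the row as NaN
            num_unaligned_images += 1
            return float("nan")
        match = re.search(r"_(\d{6}).mp4", row[key][0]["path"])
        if not match:
            raise ValueError(row[key][0]["path"])
        episode_index_per_cam[key] = int(match.group(1))
        max_episode = max(max_episode, episode_index_per_cam[key])
    if len(set(episode_index_per_cam.values())) != 1:
        raise ValueError(f"cameras disagree: {episode_index_per_cam}")
    return episode_index_per_cam[image_keys[0]]

print([get_episode_index(r) for r in rows])  # [0, nan, 1, 2]
print(num_unaligned_images, max_episode)     # 1 2
```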
```diff
@@ -111,11 +125,24 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
     del df["timestamp_utc"]
 
     # sanity check
-    has_nan = df.isna().any().any()
-    if has_nan:
-        raise ValueError("Dataset contains Nan values.")
+    num_rows_with_nan = df.isna().any(axis=1).sum()
+    assert (
+        num_rows_with_nan == num_unaligned_images
+    ), f"Found {num_rows_with_nan} rows with NaN values but {num_unaligned_images} unaligned images."
+    if num_unaligned_images > max_episode * 2:
+        # We allow a few unaligned images, typically at the beginning and end of the episodes for instance
+        # but if there are too many, we raise an error to avoid large chunks of missing data
+        raise ValueError(
+            f"Found {num_unaligned_images} unaligned images out of {max_episode} episodes. "
+            f"Check the timestamps of the cameras."
+        )
+
+    # Drop rows with NaN values now that we double checked and convert episode_index to int
+    df = df.dropna()
+    df["episode_index"] = df["episode_index"].astype(int)
 
     # sanity check episode indices go from 0 to n-1
+    assert df["episode_index"].max() == max_episode
     ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
     expected_ep_ids = list(range(df["episode_index"].max() + 1))
     if ep_ids != expected_ep_ids:
```
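The reworked sanity check leans on two pandas idioms: `df.isna().any(axis=1).sum()` counts rows containing at least one NaN, and a NaN anywhere in a column forces it to float dtype, which is why `episode_index` needs the `astype(int)` after `dropna()`. A toy illustration with invented data:

```python
import numpy as np
import pandas as pd

# One unaligned row, marked by NaN in every column.
df = pd.DataFrame({
    "episode_index": [0.0, np.nan, 1.0, 1.0],
    "observation.state": [0.1, np.nan, 0.2, 0.3],
})

num_rows_with_nan = df.isna().any(axis=1).sum()  # rows with at least one NaN
assert num_rows_with_nan == 1

df = df.dropna()  # drop the unaligned rows...
df["episode_index"] = df["episode_index"].astype(int)  # ...then restore int dtype
print(df["episode_index"].tolist())  # [0, 1, 1]
```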
```diff
@@ -214,8 +241,6 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
 
     if fps is None:
         fps = 30
-    else:
-        raise NotImplementedError()
 
     if not video:
         raise NotImplementedError()
```
```diff
@@ -243,10 +243,11 @@ def load_previous_and_future_frames(
     is_pad = min_ > tolerance_s
 
     # check violated query timestamps are all outside the episode range
-    assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
-        f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
-        "This might be due to synchronization issues with timestamps during data collection."
-    )
+    if not ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all():
+        raise ValueError(
+            f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
+            "This might be due to synchronization issues with timestamps during data collection."
+        )
 
     # get dataset indices corresponding to frames to be loaded
     data_ids = ep_data_ids[argmin_]
```
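This hunk swaps a bare `assert` for an explicit `ValueError`, which matters because asserts are stripped when Python runs with `-O`, and an exception is catchable by callers. A tiny sketch of the same check shape, with all values hypothetical:

```python
import torch

tolerance_s = 1e-4
query_ts = torch.tensor([0.00, 0.05, 1.20])  # hypothetical query timestamps
ep_first_ts, ep_last_ts = 0.0, 1.0           # hypothetical episode time range
min_ = torch.tensor([0.0, 0.0, 0.2])         # distance to the closest frame
is_pad = min_ > tolerance_s                  # frames to be padded

# Violations are only tolerated outside [ep_first_ts, ep_last_ts].
if not ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all():
    raise ValueError("timestamp violates tolerance inside episode range")
```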
```diff
@@ -189,7 +189,7 @@ class Logger:
         training_state["scheduler"] = scheduler.state_dict()
         torch.save(training_state, save_dir / self.training_state_file_name)
 
-    def save_checkpont(
+    def save_checkpoint(
         self,
         train_step: int,
         policy: Policy,
```
```diff
@@ -26,10 +26,11 @@ class ACTConfig:
     Those are: `input_shapes` and 'output_shapes`.
 
     Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
         - At least one key starting with "observation.image is required as an input.
-        - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
-          views. Right now we only support all images having the same shape.
-        - May optionally work without an "observation.state" key for the proprioceptive robot state.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views.
+          Right now we only support all images having the same shape.
         - "action" is required as an output key.
 
     Args:
```
```diff
@@ -200,12 +200,13 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
-        self.use_input_state = "observation.state" in config.input_shapes
+        self.has_state = "observation.state" in config.input_shapes
+        self.latent_dim = config.latent_dim
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            if self.use_input_state:
+            if self.has_state:
                 self.vae_encoder_robot_state_input_proj = nn.Linear(
                     config.input_shapes["observation.state"][0], config.dim_model
                 )
```
```diff
@@ -217,9 +218,7 @@ class ACT(nn.Module):
             self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
-            num_input_token_encoder = 1 + config.chunk_size
-            if self.use_input_state:
-                num_input_token_encoder += 1
+            num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
             self.register_buffer(
                 "vae_encoder_pos_enc",
                 create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
```
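One readability note on the new one-liner: Python's conditional expression binds more loosely than `+`, so the sums group as intended on each side of the `if`. A quick check with a hypothetical chunk size:

```python
chunk_size = 100

for has_state in (True, False):
    # Groups as (1 + 1 + chunk_size) if has_state else (1 + chunk_size):
    # [cls, robot_state, *action_sequence] vs. [cls, *action_sequence].
    num_input_token_encoder = 1 + 1 + chunk_size if has_state else 1 + chunk_size
    print(has_state, num_input_token_encoder)  # True 102, then False 101
```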
```diff
@@ -242,16 +241,16 @@ class ACT(nn.Module):
 
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        if self.use_input_state:
+        if self.has_state:
             self.encoder_robot_state_input_proj = nn.Linear(
                 config.input_shapes["observation.state"][0], config.dim_model
             )
-        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
+        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        num_input_token_decoder = 2 if self.use_input_state else 1
+        num_input_token_decoder = 2 if self.has_state else 1
         self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
```
```diff
@@ -299,12 +298,12 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            if self.use_input_state:
+            if self.has_state:
                 robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
 
-            if self.use_input_state:
+            if self.has_state:
                 vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
             else:
                 vae_encoder_input = [cls_embed, action_embed]
```
```diff
@@ -329,7 +328,7 @@ class ACT(nn.Module):
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
+            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
 
```
```diff
@@ -351,12 +350,12 @@ class ACT(nn.Module):
             cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
 
         # Get positional embeddings for robot state and latent.
-        if self.use_input_state:
+        if self.has_state:
             robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
 
         # Stack encoder input and positional embeddings moving to (S, B, C).
-        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
         encoder_in = torch.cat(
             [
                 torch.stack(encoder_in_feats, axis=0),
```
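The net effect of the `has_state` rename plus these guards is that every state-dependent token is simply skipped when `observation.state` is absent from `input_shapes`. A minimal shape check of the encoder input stacking, with invented dimensions:

```python
import torch

batch_size, dim_model = 2, 8
latent_embed = torch.zeros(batch_size, dim_model)       # hypothetical latent token
robot_state_embed = torch.zeros(batch_size, dim_model)  # hypothetical state token

for has_state in (True, False):
    encoder_in_feats = [latent_embed, robot_state_embed] if has_state else [latent_embed]
    encoder_in = torch.stack(encoder_in_feats, dim=0)  # (S, B, C)
    print(has_state, tuple(encoder_in.shape))  # (2, 2, 8), then (1, 2, 8)
```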
```diff
@@ -28,7 +28,10 @@ class DiffusionConfig:
 
     Notes on the inputs and outputs:
         - "observation.state" is required as an input key.
-        - A key starting with "observation.image is required as an input.
+        - At least one key starting with "observation.image is required as an input.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views.
+          Right now we only support all images having the same shape.
         - "action" is required as an output key.
 
     Args:
```
```diff
@@ -11,7 +11,7 @@
 # ```bash
 # python lerobot/scripts/train.py \
 # policy=act_real \
-# env=dora_aloha_real
+# env=aloha_real
 # ```
 
 seed: 1000
```
```diff
@@ -9,7 +9,7 @@
 # ```bash
 # python lerobot/scripts/train.py \
 # policy=act_real_no_state \
-# env=dora_aloha_real
+# env=aloha_real
 # ```
 
 seed: 1000
```
```diff
@@ -164,7 +164,10 @@ def rollout(
         # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
         # available of none of the envs finished.
         if "final_info" in info:
-            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
+            successes = [
+                info["is_success"] if info is not None and "is_success" in info else False
+                for info in info["final_info"]
+            ]
         else:
             successes = [False] * env.num_envs
 
```
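The rewritten comprehension guards against sub-env info dicts that exist but lack an `"is_success"` entry, which would previously raise a `KeyError`. A standalone illustration with an invented `final_info` in the gymnasium VectorEnv style:

```python
# One entry per sub-env: None for envs that did not finish this step, and not
# every finished env's dict carries "is_success".
final_info = [{"is_success": True}, None, {"TimeLimit.truncated": True}]

successes = [
    info["is_success"] if info is not None and "is_success" in info else False
    for info in final_info
]
print(successes)  # [True, False, False]
```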
```diff
@@ -345,7 +345,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         logging.info(f"Checkpoint policy after step {step}")
         # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
         # needed (choose 6 as a minimum for consistency without being overkill).
-        logger.save_checkpont(
+        logger.save_checkpoint(
             step,
             policy,
             optimizer,
```
poetry.lock (generated)

```diff
@@ -1104,7 +1104,7 @@ pyarrow = ">=12.0.0"
 type = "git"
 url = "https://github.com/dora-rs/dora-lerobot.git"
 reference = "HEAD"
-resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79"
+resolved_reference = "1c6c2a401c3a2967d41444be6286ca9a28893abf"
 subdirectory = "gym_dora"
 
 [[package]]
```
```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
-size 5104
+oid sha256:ebd21273f6048b66c806f92035352843a9069908b3296863fd55d34cf71cd0ef
+size 51248
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
+oid sha256:b9bbf951891077320a5da27e77ddb580a6e833e8d3162b62a2f887a1989585cc
 size 31688
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
+oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
 size 68
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
+oid sha256:42f92239223bb4df32d5c3016bc67450159f1285a7ab046307b645f699ccc34e
 size 34928
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
-size 5104
+oid sha256:52f85d6262ad1dd0b66578b25829fed96aaaca3c7458cb73ac75111350d17fcf
+size 51248
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
+oid sha256:5ba7c910618f0f3ca69f82f3d70c880d2b2e432456524a2a63dfd5c50efa45f0
 size 30808
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
+oid sha256:53ad410f43855254438790f54aa7c895a052776acdd922906ae430684f659b53
 size 33608
```
```diff
@@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     batch = next(iter(dataloader))
     obs = {}
    for k in batch:
-        if k.startswith("observation"):
+        if "observation" in k:
            obs[k] = batch[k]
 
    if "n_action_steps" in cfg.policy:
```
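The switch from `startswith` to a substring test widens the match, presumably so that prefixed keys are also collected. The difference in one line, with invented keys:

```python
keys = ["observation.state", "next.observation.state", "action"]
print([k for k in keys if k.startswith("observation")])  # ['observation.state']
print([k for k in keys if "observation" in k])  # ['observation.state', 'next.observation.state']
```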
```diff
@@ -115,8 +115,8 @@ if __name__ == "__main__":
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
         ),
         ("aloha", "act", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real", []),
+        ("dora_aloha_real", "act_real_no_state", []),
     ]
     for env, policy, extra_overrides in env_policies:
         save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
```