forked from tangger/lerobot
[HIL-SERL]Remove overstrict pre-commit modifications (#1028)
This commit is contained in:
@@ -32,11 +32,7 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from skimage.metrics import (
|
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||||
mean_squared_error,
|
|
||||||
peak_signal_noise_ratio,
|
|
||||||
structural_similarity,
|
|
||||||
)
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
@@ -98,11 +94,7 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
|
|||||||
|
|
||||||
|
|
||||||
def save_decoded_frames(
|
def save_decoded_frames(
|
||||||
imgs_dir: Path,
|
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||||
save_dir: Path,
|
|
||||||
frames: torch.Tensor,
|
|
||||||
timestamps: list[float],
|
|
||||||
fps: int,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||||
return
|
return
|
||||||
@@ -112,10 +104,7 @@ def save_decoded_frames(
|
|||||||
idx = int(ts * fps)
|
idx = int(ts * fps)
|
||||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||||
shutil.copyfile(
|
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||||
imgs_dir / f"frame_{idx:06d}.png",
|
|
||||||
save_dir / f"frame_{idx:06d}_original.png",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||||
@@ -131,11 +120,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
|||||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||||
|
|
||||||
for i, item in enumerate(
|
for i, item in enumerate(
|
||||||
tqdm(
|
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||||
imgs_dataset,
|
|
||||||
desc=f"saving {dataset.repo_id} first episode images",
|
|
||||||
leave=False,
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
img = item[img_keys[0]]
|
img = item[img_keys[0]]
|
||||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||||
@@ -290,9 +275,7 @@ def benchmark_encoding_decoding(
|
|||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
benchmark_table = []
|
benchmark_table = []
|
||||||
for timestamps_mode in tqdm(
|
for timestamps_mode in tqdm(
|
||||||
decoding_cfg["timestamps_modes"],
|
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||||
desc="decodings (timestamps_modes)",
|
|
||||||
leave=False,
|
|
||||||
):
|
):
|
||||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||||
benchmark_row = benchmark_decoding(
|
benchmark_row = benchmark_decoding(
|
||||||
|
|||||||
@@ -32,10 +32,7 @@ import torch
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.common.datasets.lerobot_dataset import (
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
LeRobotDataset,
|
|
||||||
LeRobotDatasetMetadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||||
print("List of available datasets:")
|
print("List of available datasets:")
|
||||||
|
|||||||
@@ -22,10 +22,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import (
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
LeRobotDataset,
|
|
||||||
LeRobotDatasetMetadata,
|
|
||||||
)
|
|
||||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
@@ -80,24 +77,7 @@ def main():
|
|||||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||||
# used to supervise the policy.
|
# used to supervise the policy.
|
||||||
"action": [
|
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||||
-0.1,
|
|
||||||
0.0,
|
|
||||||
0.1,
|
|
||||||
0.2,
|
|
||||||
0.3,
|
|
||||||
0.4,
|
|
||||||
0.5,
|
|
||||||
0.6,
|
|
||||||
0.7,
|
|
||||||
0.8,
|
|
||||||
0.9,
|
|
||||||
1.0,
|
|
||||||
1.1,
|
|
||||||
1.2,
|
|
||||||
1.3,
|
|
||||||
1.4,
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# We can then instantiate the dataset with these delta_timestamps configuration.
|
# We can then instantiate the dataset with these delta_timestamps configuration.
|
||||||
|
|||||||
@@ -26,10 +26,7 @@ import math
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import (
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
LeRobotDataset,
|
|
||||||
LeRobotDatasetMetadata,
|
|
||||||
)
|
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
|
|
||||||
|
|
||||||
@@ -54,24 +51,7 @@ def main():
|
|||||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||||
# used to calculate the loss.
|
# used to calculate the loss.
|
||||||
"action": [
|
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||||
-0.1,
|
|
||||||
0.0,
|
|
||||||
0.1,
|
|
||||||
0.2,
|
|
||||||
0.3,
|
|
||||||
0.4,
|
|
||||||
0.5,
|
|
||||||
0.6,
|
|
||||||
0.7,
|
|
||||||
0.8,
|
|
||||||
0.9,
|
|
||||||
1.0,
|
|
||||||
1.1,
|
|
||||||
1.2,
|
|
||||||
1.3,
|
|
||||||
1.4,
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Load the last 10% of episodes of the dataset as a validation set.
|
# Load the last 10% of episodes of the dataset as a validation set.
|
||||||
|
|||||||
@@ -19,10 +19,7 @@ from lerobot.common.datasets.utils import load_image_as_numpy
|
|||||||
|
|
||||||
|
|
||||||
def estimate_num_samples(
|
def estimate_num_samples(
|
||||||
dataset_len: int,
|
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||||
min_num_samples: int = 100,
|
|
||||||
max_num_samples: int = 10_000,
|
|
||||||
power: float = 0.75,
|
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Heuristic to estimate the number of samples based on dataset size.
|
"""Heuristic to estimate the number of samples based on dataset size.
|
||||||
The power controls the sample growth relative to dataset size.
|
The power controls the sample growth relative to dataset size.
|
||||||
@@ -126,9 +123,7 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
|||||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||||
|
|
||||||
|
|
||||||
def aggregate_feature_stats(
|
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||||
stats_ft_list: list[dict[str, dict]],
|
|
||||||
) -> dict[str, dict[str, np.ndarray]]:
|
|
||||||
"""Aggregates stats for a single feature."""
|
"""Aggregates stats for a single feature."""
|
||||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||||
@@ -157,9 +152,7 @@ def aggregate_feature_stats(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def aggregate_stats(
|
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||||
stats_list: list[dict[str, dict]],
|
|
||||||
) -> dict[str, dict[str, np.ndarray]]:
|
|
||||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||||
|
|
||||||
The final stats will have the union of all data keys from each of the stats dicts.
|
The final stats will have the union of all data keys from each of the stats dicts.
|
||||||
|
|||||||
@@ -154,32 +154,14 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
|||||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||||
# with real data rather than the dummy initialization.
|
# with real data rather than the dummy initialization.
|
||||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {
|
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||||
"dtype": np.dtype("?"),
|
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||||
"shape": (buffer_capacity,),
|
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||||
},
|
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||||
OnlineBuffer.INDEX_KEY: {
|
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||||
"dtype": np.dtype("int64"),
|
|
||||||
"shape": (buffer_capacity,),
|
|
||||||
},
|
|
||||||
OnlineBuffer.FRAME_INDEX_KEY: {
|
|
||||||
"dtype": np.dtype("int64"),
|
|
||||||
"shape": (buffer_capacity,),
|
|
||||||
},
|
|
||||||
OnlineBuffer.EPISODE_INDEX_KEY: {
|
|
||||||
"dtype": np.dtype("int64"),
|
|
||||||
"shape": (buffer_capacity,),
|
|
||||||
},
|
|
||||||
OnlineBuffer.TIMESTAMP_KEY: {
|
|
||||||
"dtype": np.dtype("float64"),
|
|
||||||
"shape": (buffer_capacity,),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
for k, v in data_spec.items():
|
for k, v in data_spec.items():
|
||||||
complete_data_spec[k] = {
|
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||||
"dtype": v["dtype"],
|
|
||||||
"shape": (buffer_capacity, *v["shape"]),
|
|
||||||
}
|
|
||||||
return complete_data_spec
|
return complete_data_spec
|
||||||
|
|
||||||
def add_data(self, data: dict[str, np.ndarray]):
|
def add_data(self, data: dict[str, np.ndarray]):
|
||||||
|
|||||||
@@ -77,9 +77,7 @@ def check_repo_id(repo_id: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): remove
|
# TODO(aliberts): remove
|
||||||
def calculate_episode_data_index(
|
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||||
hf_dataset: datasets.Dataset,
|
|
||||||
) -> Dict[str, torch.Tensor]:
|
|
||||||
"""
|
"""
|
||||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||||
|
|
||||||
|
|||||||
@@ -43,10 +43,7 @@ class EpisodeAwareSampler:
|
|||||||
):
|
):
|
||||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||||
indices.extend(
|
indices.extend(
|
||||||
range(
|
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||||
start_index.item() + drop_n_first_frames,
|
|
||||||
end_index.item() - drop_n_last_frames,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.indices = indices
|
self.indices = indices
|
||||||
|
|||||||
@@ -118,10 +118,7 @@ DATASETS = {
|
|||||||
"single_task": "Place the battery into the slot of the remote controller.",
|
"single_task": "Place the battery into the slot of the remote controller.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
"aloha_static_candy": {
|
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
|
||||||
"single_task": "Pick up the candy and unwrap it.",
|
|
||||||
**ALOHA_STATIC_INFO,
|
|
||||||
},
|
|
||||||
"aloha_static_coffee": {
|
"aloha_static_coffee": {
|
||||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
@@ -170,22 +167,13 @@ DATASETS = {
|
|||||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
"aloha_static_ziploc_slide": {
|
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||||
"single_task": "Slide open the ziploc bag.",
|
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||||
**ALOHA_STATIC_INFO,
|
|
||||||
},
|
|
||||||
"aloha_sim_insertion_scripted": {
|
|
||||||
"single_task": "Insert the peg into the socket.",
|
|
||||||
**ALOHA_STATIC_INFO,
|
|
||||||
},
|
|
||||||
"aloha_sim_insertion_scripted_image": {
|
"aloha_sim_insertion_scripted_image": {
|
||||||
"single_task": "Insert the peg into the socket.",
|
"single_task": "Insert the peg into the socket.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
"aloha_sim_insertion_human": {
|
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||||
"single_task": "Insert the peg into the socket.",
|
|
||||||
**ALOHA_STATIC_INFO,
|
|
||||||
},
|
|
||||||
"aloha_sim_insertion_human_image": {
|
"aloha_sim_insertion_human_image": {
|
||||||
"single_task": "Insert the peg into the socket.",
|
"single_task": "Insert the peg into the socket.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
@@ -206,19 +194,10 @@ DATASETS = {
|
|||||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
"pusht": {
|
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||||
**PUSHT_INFO,
|
|
||||||
},
|
|
||||||
"pusht_image": {
|
|
||||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
|
||||||
**PUSHT_INFO,
|
|
||||||
},
|
|
||||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||||
"unitreeh1_rearrange_objects": {
|
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
|
||||||
"single_task": "Put the object into the bin.",
|
|
||||||
**UNITREEH_INFO,
|
|
||||||
},
|
|
||||||
"unitreeh1_two_robot_greeting": {
|
"unitreeh1_two_robot_greeting": {
|
||||||
"single_task": "Greet the other robot with a high five.",
|
"single_task": "Greet the other robot with a high five.",
|
||||||
**UNITREEH_INFO,
|
**UNITREEH_INFO,
|
||||||
@@ -228,31 +207,13 @@ DATASETS = {
|
|||||||
**UNITREEH_INFO,
|
**UNITREEH_INFO,
|
||||||
},
|
},
|
||||||
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||||
"xarm_lift_medium_image": {
|
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||||
"single_task": "Pick up the cube and lift it.",
|
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||||
**XARM_INFO,
|
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||||
},
|
|
||||||
"xarm_lift_medium_replay": {
|
|
||||||
"single_task": "Pick up the cube and lift it.",
|
|
||||||
**XARM_INFO,
|
|
||||||
},
|
|
||||||
"xarm_lift_medium_replay_image": {
|
|
||||||
"single_task": "Pick up the cube and lift it.",
|
|
||||||
**XARM_INFO,
|
|
||||||
},
|
|
||||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||||
"xarm_push_medium_image": {
|
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||||
"single_task": "Push the cube onto the target.",
|
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||||
**XARM_INFO,
|
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||||
},
|
|
||||||
"xarm_push_medium_replay": {
|
|
||||||
"single_task": "Push the cube onto the target.",
|
|
||||||
**XARM_INFO,
|
|
||||||
},
|
|
||||||
"xarm_push_medium_replay_image": {
|
|
||||||
"single_task": "Push the cube onto the target.",
|
|
||||||
**XARM_INFO,
|
|
||||||
},
|
|
||||||
"umi_cup_in_the_wild": {
|
"umi_cup_in_the_wild": {
|
||||||
"single_task": "Put the cup on the plate.",
|
"single_task": "Put the cup on the plate.",
|
||||||
"license": "apache-2.0",
|
"license": "apache-2.0",
|
||||||
|
|||||||
@@ -379,12 +379,7 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
|
|||||||
for i in range(0, len(lfs_untracked_videos), 100):
|
for i in range(0, len(lfs_untracked_videos), 100):
|
||||||
files = lfs_untracked_videos[i : i + 100]
|
files = lfs_untracked_videos[i : i + 100]
|
||||||
try:
|
try:
|
||||||
subprocess.run(
|
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
|
||||||
["git", "rm", "--cached", *files],
|
|
||||||
cwd=work_dir,
|
|
||||||
capture_output=True,
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print("git rm --cached ERROR:")
|
print("git rm --cached ERROR:")
|
||||||
print(e.stderr)
|
print(e.stderr)
|
||||||
@@ -407,17 +402,7 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
|||||||
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
||||||
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[
|
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
|
||||||
"git",
|
|
||||||
"clone",
|
|
||||||
"--branch",
|
|
||||||
branch,
|
|
||||||
"--single-branch",
|
|
||||||
"--depth",
|
|
||||||
"1",
|
|
||||||
repo_url,
|
|
||||||
str(work_dir),
|
|
||||||
],
|
|
||||||
check=True,
|
check=True,
|
||||||
env=env,
|
env=env,
|
||||||
)
|
)
|
||||||
@@ -425,11 +410,7 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
|||||||
|
|
||||||
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
||||||
lfs_tracked_files = subprocess.run(
|
lfs_tracked_files = subprocess.run(
|
||||||
["git", "lfs", "ls-files", "-n"],
|
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
|
||||||
cwd=work_dir,
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
check=True,
|
|
||||||
)
|
)
|
||||||
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
||||||
return [f for f in video_files if f not in lfs_tracked_files]
|
return [f for f in video_files if f not in lfs_tracked_files]
|
||||||
@@ -443,11 +424,7 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
|
|||||||
]
|
]
|
||||||
hub_api = HfApi()
|
hub_api = HfApi()
|
||||||
hub_api.snapshot_download(
|
hub_api.snapshot_download(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=local_dir,
|
|
||||||
revision=branch,
|
|
||||||
allow_patterns=video_files,
|
|
||||||
)
|
)
|
||||||
videos_info_dict = {}
|
videos_info_dict = {}
|
||||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||||
@@ -474,11 +451,7 @@ def convert_dataset(
|
|||||||
|
|
||||||
hub_api = HfApi()
|
hub_api = HfApi()
|
||||||
hub_api.snapshot_download(
|
hub_api.snapshot_download(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
|
||||||
repo_type="dataset",
|
|
||||||
revision=v1,
|
|
||||||
local_dir=v1x_dir,
|
|
||||||
ignore_patterns="videos*/",
|
|
||||||
)
|
)
|
||||||
branch = "main"
|
branch = "main"
|
||||||
if test_branch:
|
if test_branch:
|
||||||
@@ -536,21 +509,12 @@ def convert_dataset(
|
|||||||
dataset = dataset.remove_columns(video_keys)
|
dataset = dataset.remove_columns(video_keys)
|
||||||
clean_gitattr = Path(
|
clean_gitattr = Path(
|
||||||
hub_api.hf_hub_download(
|
hub_api.hf_hub_download(
|
||||||
repo_id=GITATTRIBUTES_REF,
|
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=local_dir,
|
|
||||||
filename=".gitattributes",
|
|
||||||
)
|
)
|
||||||
).absolute()
|
).absolute()
|
||||||
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
||||||
move_videos(
|
move_videos(
|
||||||
repo_id,
|
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
|
||||||
video_keys,
|
|
||||||
total_episodes,
|
|
||||||
total_chunks,
|
|
||||||
Path(tmp_video_dir),
|
|
||||||
clean_gitattr,
|
|
||||||
branch,
|
|
||||||
)
|
)
|
||||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||||
for key in video_keys:
|
for key in video_keys:
|
||||||
@@ -579,11 +543,7 @@ def convert_dataset(
|
|||||||
|
|
||||||
# Episodes
|
# Episodes
|
||||||
episodes = [
|
episodes = [
|
||||||
{
|
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||||
"episode_index": ep_idx,
|
|
||||||
"tasks": tasks_by_episodes[ep_idx],
|
|
||||||
"length": episode_lengths[ep_idx],
|
|
||||||
}
|
|
||||||
for ep_idx in episode_indices
|
for ep_idx in episode_indices
|
||||||
]
|
]
|
||||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||||
@@ -612,12 +572,7 @@ def convert_dataset(
|
|||||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||||
|
|
||||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||||
hub_api.delete_folder(
|
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||||
repo_id=repo_id,
|
|
||||||
path_in_repo="meta_data",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=branch,
|
|
||||||
)
|
|
||||||
|
|
||||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||||
|
|||||||
@@ -37,16 +37,8 @@ import logging
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||||
EPISODES_STATS_PATH,
|
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||||
STATS_PATH,
|
|
||||||
load_stats,
|
|
||||||
write_info,
|
|
||||||
)
|
|
||||||
from lerobot.common.datasets.v21.convert_stats import (
|
|
||||||
check_aggregate_stats,
|
|
||||||
convert_stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
V20 = "v2.0"
|
V20 = "v2.0"
|
||||||
V21 = "v2.1"
|
V21 = "v2.1"
|
||||||
@@ -87,16 +79,10 @@ def convert_dataset(
|
|||||||
|
|
||||||
hub_api = HfApi()
|
hub_api = HfApi()
|
||||||
if hub_api.file_exists(
|
if hub_api.file_exists(
|
||||||
repo_id=dataset.repo_id,
|
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||||
filename=STATS_PATH,
|
|
||||||
revision=branch,
|
|
||||||
repo_type="dataset",
|
|
||||||
):
|
):
|
||||||
hub_api.delete_file(
|
hub_api.delete_file(
|
||||||
path_in_repo=STATS_PATH,
|
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||||
repo_id=dataset.repo_id,
|
|
||||||
revision=branch,
|
|
||||||
repo_type="dataset",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||||
|
|||||||
@@ -17,11 +17,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import (
|
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||||
aggregate_stats,
|
|
||||||
get_feature_stats,
|
|
||||||
sample_indices,
|
|
||||||
)
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import write_episode_stats
|
from lerobot.common.datasets.utils import write_episode_stats
|
||||||
|
|
||||||
@@ -99,9 +95,5 @@ def check_aggregate_stats(
|
|||||||
if key in reference_stats and stat in reference_stats[key]:
|
if key in reference_stats and stat in reference_stats[key]:
|
||||||
err_msg = f"feature='{key}' stats='{stat}'"
|
err_msg = f"feature='{key}' stats='{stat}'"
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
val,
|
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
||||||
reference_stats[key][stat],
|
|
||||||
rtol=rtol,
|
|
||||||
atol=atol,
|
|
||||||
err_msg=err_msg,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -49,11 +49,7 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
|
|||||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
|
||||||
**asdict(self),
|
|
||||||
"num_training_steps": num_training_steps,
|
|
||||||
"optimizer": optimizer,
|
|
||||||
}
|
|
||||||
return get_scheduler(**kwargs)
|
return get_scheduler(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,10 +71,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
|||||||
progress = float(adjusted_step - self.num_warmup_steps) / float(
|
progress = float(adjusted_step - self.num_warmup_steps) / float(
|
||||||
max(1, num_training_steps - self.num_warmup_steps)
|
max(1, num_training_steps - self.num_warmup_steps)
|
||||||
)
|
)
|
||||||
return max(
|
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
|
||||||
0.0,
|
|
||||||
0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)),
|
|
||||||
)
|
|
||||||
|
|
||||||
return LambdaLR(optimizer, lr_lambda, -1)
|
return LambdaLR(optimizer, lr_lambda, -1)
|
||||||
|
|
||||||
|
|||||||
@@ -241,9 +241,7 @@ class ACTTemporalEnsembler:
|
|||||||
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||||
# operations later.
|
# operations later.
|
||||||
self.ensembled_actions_count = torch.ones(
|
self.ensembled_actions_count = torch.ones(
|
||||||
(self.chunk_size, 1),
|
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||||
dtype=torch.long,
|
|
||||||
device=self.ensembled_actions.device,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||||
@@ -255,10 +253,7 @@ class ACTTemporalEnsembler:
|
|||||||
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||||
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||||
self.ensembled_actions_count = torch.cat(
|
self.ensembled_actions_count = torch.cat(
|
||||||
[
|
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||||
self.ensembled_actions_count,
|
|
||||||
torch.ones_like(self.ensembled_actions_count[-1:]),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
# "Consume" the first action.
|
# "Consume" the first action.
|
||||||
action, self.ensembled_actions, self.ensembled_actions_count = (
|
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||||
@@ -338,11 +333,7 @@ class ACT(nn.Module):
|
|||||||
# Backbone for image feature extraction.
|
# Backbone for image feature extraction.
|
||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||||
replace_stride_with_dilation=[
|
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||||
False,
|
|
||||||
False,
|
|
||||||
config.replace_final_stride_with_dilation,
|
|
||||||
],
|
|
||||||
weights=config.pretrained_backbone_weights,
|
weights=config.pretrained_backbone_weights,
|
||||||
norm_layer=FrozenBatchNorm2d,
|
norm_layer=FrozenBatchNorm2d,
|
||||||
)
|
)
|
||||||
@@ -436,11 +427,7 @@ class ACT(nn.Module):
|
|||||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||||
|
|
||||||
if self.config.robot_state_feature:
|
if self.config.robot_state_feature:
|
||||||
vae_encoder_input = [
|
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||||
cls_embed,
|
|
||||||
robot_state_embed,
|
|
||||||
action_embed,
|
|
||||||
] # (B, S+2, D)
|
|
||||||
else:
|
else:
|
||||||
vae_encoder_input = [cls_embed, action_embed]
|
vae_encoder_input = [cls_embed, action_embed]
|
||||||
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||||
@@ -553,10 +540,7 @@ class ACTEncoder(nn.Module):
|
|||||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||||
x: Tensor,
|
|
||||||
pos_embed: Tensor | None = None,
|
|
||||||
key_padding_mask: Tensor | None = None,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
||||||
@@ -619,10 +603,7 @@ class ACTDecoder(nn.Module):
|
|||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(
|
x = layer(
|
||||||
x,
|
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||||
encoder_out,
|
|
||||||
decoder_pos_embed=decoder_pos_embed,
|
|
||||||
encoder_pos_embed=encoder_pos_embed,
|
|
||||||
)
|
)
|
||||||
if self.norm is not None:
|
if self.norm is not None:
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|||||||
@@ -209,10 +209,7 @@ class DiffusionModel(nn.Module):
|
|||||||
|
|
||||||
# ========= inference ============
|
# ========= inference ============
|
||||||
def conditional_sample(
|
def conditional_sample(
|
||||||
self,
|
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||||
batch_size: int,
|
|
||||||
global_cond: Tensor | None = None,
|
|
||||||
generator: torch.Generator | None = None,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
device = get_device_from_parameters(self)
|
device = get_device_from_parameters(self)
|
||||||
dtype = get_dtype_from_parameters(self)
|
dtype = get_dtype_from_parameters(self)
|
||||||
@@ -257,10 +254,7 @@ class DiffusionModel(nn.Module):
|
|||||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||||
# feature dim (effectively concatenating the camera features).
|
# feature dim (effectively concatenating the camera features).
|
||||||
img_features = einops.rearrange(
|
img_features = einops.rearrange(
|
||||||
img_features_list,
|
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||||
"(n b s) ... -> b s (n ...)",
|
|
||||||
b=batch_size,
|
|
||||||
s=n_obs_steps,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||||
@@ -270,10 +264,7 @@ class DiffusionModel(nn.Module):
|
|||||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||||
# feature dim (effectively concatenating the camera features).
|
# feature dim (effectively concatenating the camera features).
|
||||||
img_features = einops.rearrange(
|
img_features = einops.rearrange(
|
||||||
img_features,
|
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||||
"(b s n) ... -> b s (n ...)",
|
|
||||||
b=batch_size,
|
|
||||||
s=n_obs_steps,
|
|
||||||
)
|
)
|
||||||
global_cond_feats.append(img_features)
|
global_cond_feats.append(img_features)
|
||||||
|
|
||||||
@@ -524,9 +515,7 @@ class DiffusionRgbEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def _replace_submodules(
|
def _replace_submodules(
|
||||||
root_module: nn.Module,
|
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||||
predicate: Callable[[nn.Module], bool],
|
|
||||||
func: Callable[[nn.Module], nn.Module],
|
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -644,14 +633,10 @@ class DiffusionConditionalUnet1d(nn.Module):
|
|||||||
self.mid_modules = nn.ModuleList(
|
self.mid_modules = nn.ModuleList(
|
||||||
[
|
[
|
||||||
DiffusionConditionalResidualBlock1d(
|
DiffusionConditionalResidualBlock1d(
|
||||||
config.down_dims[-1],
|
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||||
config.down_dims[-1],
|
|
||||||
**common_res_block_kwargs,
|
|
||||||
),
|
),
|
||||||
DiffusionConditionalResidualBlock1d(
|
DiffusionConditionalResidualBlock1d(
|
||||||
config.down_dims[-1],
|
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||||
config.down_dims[-1],
|
|
||||||
**common_res_block_kwargs,
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -61,11 +61,7 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||||
|
|
||||||
PRECISIONS = {
|
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||||
"bfloat16": torch.bfloat16,
|
|
||||||
"float32": torch.float32,
|
|
||||||
"float16": torch.float16,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def slice_paligemma_state_dict(state_dict, config):
|
def slice_paligemma_state_dict(state_dict, config):
|
||||||
|
|||||||
@@ -48,32 +48,18 @@ def flex_attention_forward(
|
|||||||
|
|
||||||
key_states = key_states[:, :, :, None, :]
|
key_states = key_states[:, :, :, None, :]
|
||||||
key_states = key_states.expand(
|
key_states = key_states.expand(
|
||||||
batch_size,
|
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||||
key_states.shape[1],
|
|
||||||
num_key_value_heads,
|
|
||||||
num_key_value_groups,
|
|
||||||
head_dim,
|
|
||||||
)
|
)
|
||||||
key_states = key_states.reshape(
|
key_states = key_states.reshape(
|
||||||
batch_size,
|
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||||
key_states.shape[1],
|
|
||||||
num_key_value_heads * num_key_value_groups,
|
|
||||||
head_dim,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
value_states = value_states[:, :, :, None, :]
|
value_states = value_states[:, :, :, None, :]
|
||||||
value_states = value_states.expand(
|
value_states = value_states.expand(
|
||||||
batch_size,
|
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||||
value_states.shape[1],
|
|
||||||
num_key_value_heads,
|
|
||||||
num_key_value_groups,
|
|
||||||
head_dim,
|
|
||||||
)
|
)
|
||||||
value_states = value_states.reshape(
|
value_states = value_states.reshape(
|
||||||
batch_size,
|
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||||
value_states.shape[1],
|
|
||||||
num_key_value_heads * num_key_value_groups,
|
|
||||||
head_dim,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|||||||
@@ -69,11 +69,7 @@ from lerobot.common.utils.utils import get_safe_dtype
|
|||||||
|
|
||||||
|
|
||||||
def create_sinusoidal_pos_embedding(
|
def create_sinusoidal_pos_embedding(
|
||||||
time: torch.tensor,
|
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||||
dimension: int,
|
|
||||||
min_period: float,
|
|
||||||
max_period: float,
|
|
||||||
device="cpu",
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||||
if dimension % 2 != 0:
|
if dimension % 2 != 0:
|
||||||
@@ -581,11 +577,7 @@ class PI0FlowMatching(nn.Module):
|
|||||||
|
|
||||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||||
time_emb = create_sinusoidal_pos_embedding(
|
time_emb = create_sinusoidal_pos_embedding(
|
||||||
timestep,
|
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
||||||
self.config.proj_width,
|
|
||||||
min_period=4e-3,
|
|
||||||
max_period=4.0,
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
time_emb = time_emb.type(dtype=dtype)
|
time_emb = time_emb.type(dtype=dtype)
|
||||||
|
|
||||||
@@ -617,15 +609,7 @@ class PI0FlowMatching(nn.Module):
|
|||||||
return embs, pad_masks, att_masks
|
return embs, pad_masks, att_masks
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||||
images,
|
|
||||||
img_masks,
|
|
||||||
lang_tokens,
|
|
||||||
lang_masks,
|
|
||||||
state,
|
|
||||||
actions,
|
|
||||||
noise=None,
|
|
||||||
time=None,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||||
if noise is None:
|
if noise is None:
|
||||||
@@ -671,11 +655,7 @@ class PI0FlowMatching(nn.Module):
|
|||||||
device = state.device
|
device = state.device
|
||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
actions_shape = (
|
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
||||||
bsize,
|
|
||||||
self.config.n_action_steps,
|
|
||||||
self.config.max_action_dim,
|
|
||||||
)
|
|
||||||
noise = self.sample_noise(actions_shape, device)
|
noise = self.sample_noise(actions_shape, device)
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||||
|
|||||||
@@ -293,18 +293,12 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
|||||||
# in `transformers`. (molbap)
|
# in `transformers`. (molbap)
|
||||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||||
value_states = torch.cat(
|
value_states = torch.cat(
|
||||||
[past_key_values[layer_idx]["value_states"], value_states],
|
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||||
dim=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
attention_interface = self.get_attention_interface()
|
attention_interface = self.get_attention_interface()
|
||||||
att_output = attention_interface(
|
att_output = attention_interface(
|
||||||
attention_mask,
|
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||||
batch_size,
|
|
||||||
head_dim,
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
)
|
)
|
||||||
att_output = att_output.to(dtype=torch.bfloat16)
|
att_output = att_output.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
@@ -364,24 +358,12 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
|||||||
return attention_interface
|
return attention_interface
|
||||||
|
|
||||||
def flash_attention_forward(
|
def flash_attention_forward(
|
||||||
self,
|
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||||
attention_mask,
|
|
||||||
batch_size,
|
|
||||||
head_dim,
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
):
|
):
|
||||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||||
|
|
||||||
def eager_attention_forward(
|
def eager_attention_forward(
|
||||||
self,
|
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||||
attention_mask,
|
|
||||||
batch_size,
|
|
||||||
head_dim,
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
):
|
):
|
||||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||||
@@ -393,31 +375,17 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
|||||||
sequence_length = key_states.shape[1]
|
sequence_length = key_states.shape[1]
|
||||||
|
|
||||||
key_states = key_states[:, :, :, None, :].expand(
|
key_states = key_states[:, :, :, None, :].expand(
|
||||||
batch_size,
|
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||||
sequence_length,
|
|
||||||
num_key_value_heads,
|
|
||||||
num_key_value_groups,
|
|
||||||
head_dim,
|
|
||||||
)
|
)
|
||||||
key_states = key_states.reshape(
|
key_states = key_states.reshape(
|
||||||
batch_size,
|
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||||
sequence_length,
|
|
||||||
num_key_value_heads * num_key_value_groups,
|
|
||||||
head_dim,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
value_states = value_states[:, :, :, None, :].expand(
|
value_states = value_states[:, :, :, None, :].expand(
|
||||||
batch_size,
|
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||||
sequence_length,
|
|
||||||
num_key_value_heads,
|
|
||||||
num_key_value_groups,
|
|
||||||
head_dim,
|
|
||||||
)
|
)
|
||||||
value_states = value_states.reshape(
|
value_states = value_states.reshape(
|
||||||
batch_size,
|
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||||
sequence_length,
|
|
||||||
num_key_value_heads * num_key_value_groups,
|
|
||||||
head_dim,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||||
|
|||||||
@@ -39,11 +39,7 @@ from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
|||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.common.policies.utils import (
|
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||||
get_device_from_parameters,
|
|
||||||
get_output_shape,
|
|
||||||
populate_queues,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TDMPCPolicy(PreTrainedPolicy):
|
class TDMPCPolicy(PreTrainedPolicy):
|
||||||
@@ -67,11 +63,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
config_class = TDMPCConfig
|
config_class = TDMPCConfig
|
||||||
name = "tdmpc"
|
name = "tdmpc"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||||
self,
|
|
||||||
config: TDMPCConfig,
|
|
||||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||||
@@ -197,20 +189,13 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||||
# trajectories.
|
# trajectories.
|
||||||
z = einops.repeat(
|
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||||
z,
|
|
||||||
"b d -> n b d",
|
|
||||||
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||||
# algorithm.
|
# algorithm.
|
||||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||||
mean = torch.zeros(
|
mean = torch.zeros(
|
||||||
self.config.horizon,
|
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
|
||||||
batch_size,
|
|
||||||
self.config.action_feature.shape[0],
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
# Maybe warm start CEM with the mean from the previous step.
|
# Maybe warm start CEM with the mean from the previous step.
|
||||||
if self._prev_mean is not None:
|
if self._prev_mean is not None:
|
||||||
@@ -306,10 +291,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
if self.config.q_ensemble_size > 2:
|
if self.config.q_ensemble_size > 2:
|
||||||
G += (
|
G += (
|
||||||
running_discount
|
running_discount
|
||||||
* torch.min(
|
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
|
||||||
terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))],
|
0
|
||||||
dim=0,
|
]
|
||||||
)[0]
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
G += running_discount * torch.min(terminal_values, dim=0)[0]
|
G += running_discount * torch.min(terminal_values, dim=0)[0]
|
||||||
@@ -345,10 +329,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
# Apply random image augmentations.
|
# Apply random image augmentations.
|
||||||
if self.config.image_features and self.config.max_random_shift_ratio > 0:
|
if self.config.image_features and self.config.max_random_shift_ratio > 0:
|
||||||
observations["observation.image"] = flatten_forward_unflatten(
|
observations["observation.image"] = flatten_forward_unflatten(
|
||||||
partial(
|
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||||
random_shifts_aug,
|
|
||||||
max_random_shift_ratio=self.config.max_random_shift_ratio,
|
|
||||||
),
|
|
||||||
observations["observation.image"],
|
observations["observation.image"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -572,10 +553,7 @@ class TDMPCTOLD(nn.Module):
|
|||||||
self._Qs = nn.ModuleList(
|
self._Qs = nn.ModuleList(
|
||||||
[
|
[
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
nn.Linear(
|
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||||
config.latent_dim + config.action_feature.shape[0],
|
|
||||||
config.mlp_dim,
|
|
||||||
),
|
|
||||||
nn.LayerNorm(config.mlp_dim),
|
nn.LayerNorm(config.mlp_dim),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||||
@@ -724,26 +702,11 @@ class TDMPCObservationEncoder(nn.Module):
|
|||||||
stride=2,
|
stride=2,
|
||||||
),
|
),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(
|
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||||
config.image_encoder_hidden_dim,
|
|
||||||
config.image_encoder_hidden_dim,
|
|
||||||
5,
|
|
||||||
stride=2,
|
|
||||||
),
|
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(
|
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||||
config.image_encoder_hidden_dim,
|
|
||||||
config.image_encoder_hidden_dim,
|
|
||||||
3,
|
|
||||||
stride=2,
|
|
||||||
),
|
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(
|
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||||
config.image_encoder_hidden_dim,
|
|
||||||
config.image_encoder_hidden_dim,
|
|
||||||
3,
|
|
||||||
stride=2,
|
|
||||||
),
|
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
)
|
)
|
||||||
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
|
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
|
||||||
@@ -786,8 +749,7 @@ class TDMPCObservationEncoder(nn.Module):
|
|||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
feat.append(
|
feat.append(
|
||||||
flatten_forward_unflatten(
|
flatten_forward_unflatten(
|
||||||
self.image_enc_layers,
|
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
|
||||||
obs_dict[next(iter(self.config.image_features))],
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if self.config.env_state_feature:
|
if self.config.env_state_feature:
|
||||||
@@ -834,9 +796,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
|||||||
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
|
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
|
||||||
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
|
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
|
||||||
for (n_p_ema, p_ema), (n_p, p) in zip(
|
for (n_p_ema, p_ema), (n_p, p) in zip(
|
||||||
ema_module.named_parameters(recurse=False),
|
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
|
||||||
module.named_parameters(recurse=False),
|
|
||||||
strict=True,
|
|
||||||
):
|
):
|
||||||
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
|
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
|
||||||
if isinstance(p, dict):
|
if isinstance(p, dict):
|
||||||
|
|||||||
@@ -193,12 +193,7 @@ class VQBeTConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list:
|
def action_delta_indices(self) -> list:
|
||||||
return list(
|
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
|
||||||
range(
|
|
||||||
1 - self.n_obs_steps,
|
|
||||||
self.n_action_pred_token + self.action_chunk_size - 1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reward_delta_indices(self) -> None:
|
def reward_delta_indices(self) -> None:
|
||||||
|
|||||||
@@ -29,11 +29,7 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.utils import (
|
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||||
get_device_from_parameters,
|
|
||||||
get_output_shape,
|
|
||||||
populate_queues,
|
|
||||||
)
|
|
||||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||||
|
|
||||||
@@ -328,8 +324,7 @@ class VQBeTModel(nn.Module):
|
|||||||
|
|
||||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||||
self.state_projector = MLP(
|
self.state_projector = MLP(
|
||||||
config.robot_state_feature.shape[0],
|
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
|
||||||
hidden_channels=[self.config.gpt_input_dim],
|
|
||||||
)
|
)
|
||||||
self.rgb_feature_projector = MLP(
|
self.rgb_feature_projector = MLP(
|
||||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||||
@@ -359,11 +354,7 @@ class VQBeTModel(nn.Module):
|
|||||||
)
|
)
|
||||||
# Separate batch and sequence dims.
|
# Separate batch and sequence dims.
|
||||||
img_features = einops.rearrange(
|
img_features = einops.rearrange(
|
||||||
img_features,
|
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
|
||||||
"(b s n) ... -> b s n ...",
|
|
||||||
b=batch_size,
|
|
||||||
s=n_obs_steps,
|
|
||||||
n=self.num_images,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Arrange prior and current observation step tokens as shown in the class docstring.
|
# Arrange prior and current observation step tokens as shown in the class docstring.
|
||||||
@@ -400,11 +391,7 @@ class VQBeTModel(nn.Module):
|
|||||||
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||||
if len_additional_action_token > 0:
|
if len_additional_action_token > 0:
|
||||||
features = torch.cat(
|
features = torch.cat(
|
||||||
[
|
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||||
features[:, historical_act_pred_index],
|
|
||||||
features[:, -len_additional_action_token:],
|
|
||||||
],
|
|
||||||
dim=1,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
features = features[:, historical_act_pred_index]
|
features = features[:, historical_act_pred_index]
|
||||||
@@ -527,13 +514,7 @@ class VQBeTHead(nn.Module):
|
|||||||
|
|
||||||
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
|
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
|
||||||
torch.cat(
|
torch.cat(
|
||||||
(
|
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
|
||||||
x,
|
|
||||||
F.one_hot(
|
|
||||||
sampled_primary_centers,
|
|
||||||
num_classes=self.config.vqvae_n_embed,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
axis=1,
|
axis=1,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -551,9 +532,7 @@ class VQBeTHead(nn.Module):
|
|||||||
else:
|
else:
|
||||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||||
cbet_logits = einops.rearrange(
|
cbet_logits = einops.rearrange(
|
||||||
cbet_logits,
|
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||||
"(NT) (G C) -> (NT) G C",
|
|
||||||
G=self.vqvae_model.vqvae_num_layers,
|
|
||||||
)
|
)
|
||||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||||
NT, G, choices = cbet_probs.shape
|
NT, G, choices = cbet_probs.shape
|
||||||
@@ -751,9 +730,7 @@ class VQBeTRgbEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def _replace_submodules(
|
def _replace_submodules(
|
||||||
root_module: nn.Module,
|
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||||
predicate: Callable[[nn.Module], bool],
|
|
||||||
func: Callable[[nn.Module], nn.Module],
|
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -377,10 +377,7 @@ class ResidualVQ(nn.Module):
|
|||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
VectorQuantize(
|
VectorQuantize(
|
||||||
dim=codebook_dim,
|
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
|
||||||
codebook_dim=codebook_dim,
|
|
||||||
accept_image_fmap=accept_image_fmap,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
for _ in range(num_quantizers)
|
for _ in range(num_quantizers)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -297,11 +297,7 @@ class IntelRealSenseCamera:
|
|||||||
if self.fps and self.capture_width and self.capture_height:
|
if self.fps and self.capture_width and self.capture_height:
|
||||||
# TODO(rcadene): can we set rgb8 directly?
|
# TODO(rcadene): can we set rgb8 directly?
|
||||||
config.enable_stream(
|
config.enable_stream(
|
||||||
rs.stream.color,
|
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
|
||||||
self.capture_width,
|
|
||||||
self.capture_height,
|
|
||||||
rs.format.rgb8,
|
|
||||||
self.fps,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config.enable_stream(rs.stream.color)
|
config.enable_stream(rs.stream.color)
|
||||||
@@ -309,11 +305,7 @@ class IntelRealSenseCamera:
|
|||||||
if self.use_depth:
|
if self.use_depth:
|
||||||
if self.fps and self.capture_width and self.capture_height:
|
if self.fps and self.capture_width and self.capture_height:
|
||||||
config.enable_stream(
|
config.enable_stream(
|
||||||
rs.stream.depth,
|
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
|
||||||
self.capture_width,
|
|
||||||
self.capture_height,
|
|
||||||
rs.format.z16,
|
|
||||||
self.fps,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config.enable_stream(rs.stream.depth)
|
config.enable_stream(rs.stream.depth)
|
||||||
|
|||||||
@@ -41,9 +41,7 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
|
|||||||
cameras[key] = OpenCVCamera(cfg)
|
cameras[key] = OpenCVCamera(cfg)
|
||||||
|
|
||||||
elif cfg.type == "intelrealsense":
|
elif cfg.type == "intelrealsense":
|
||||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||||
IntelRealSenseCamera,
|
|
||||||
)
|
|
||||||
|
|
||||||
cameras[key] = IntelRealSenseCamera(cfg)
|
cameras[key] = IntelRealSenseCamera(cfg)
|
||||||
else:
|
else:
|
||||||
@@ -60,9 +58,7 @@ def make_camera(camera_type, **kwargs) -> Camera:
|
|||||||
return OpenCVCamera(config)
|
return OpenCVCamera(config)
|
||||||
|
|
||||||
elif camera_type == "intelrealsense":
|
elif camera_type == "intelrealsense":
|
||||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||||
IntelRealSenseCamera,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = IntelRealSenseCameraConfig(**kwargs)
|
config = IntelRealSenseCameraConfig(**kwargs)
|
||||||
return IntelRealSenseCamera(config)
|
return IntelRealSenseCamera(config)
|
||||||
|
|||||||
@@ -23,10 +23,7 @@ import numpy as np
|
|||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||||
from lerobot.common.robot_devices.utils import (
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
RobotDeviceAlreadyConnectedError,
|
|
||||||
RobotDeviceNotConnectedError,
|
|
||||||
)
|
|
||||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||||
|
|
||||||
PROTOCOL_VERSION = 2.0
|
PROTOCOL_VERSION = 2.0
|
||||||
@@ -787,12 +784,7 @@ class DynamixelMotorsBus:
|
|||||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def write(
|
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||||
self,
|
|
||||||
data_name,
|
|
||||||
values: int | float | np.ndarray,
|
|
||||||
motor_names: str | list[str] | None = None,
|
|
||||||
):
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(
|
raise RobotDeviceNotConnectedError(
|
||||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||||
|
|||||||
@@ -23,10 +23,7 @@ import numpy as np
|
|||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
|
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
|
||||||
from lerobot.common.robot_devices.utils import (
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
RobotDeviceAlreadyConnectedError,
|
|
||||||
RobotDeviceNotConnectedError,
|
|
||||||
)
|
|
||||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||||
|
|
||||||
PROTOCOL_VERSION = 0
|
PROTOCOL_VERSION = 0
|
||||||
@@ -812,12 +809,7 @@ class FeetechMotorsBus:
|
|||||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def write(
|
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||||
self,
|
|
||||||
data_name,
|
|
||||||
values: int | float | np.ndarray,
|
|
||||||
motor_names: str | list[str] | None = None,
|
|
||||||
):
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(
|
raise RobotDeviceNotConnectedError(
|
||||||
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||||
|
|||||||
@@ -30,9 +30,7 @@ class MotorsBus(Protocol):
|
|||||||
def write(self): ...
|
def write(self): ...
|
||||||
|
|
||||||
|
|
||||||
def make_motors_buses_from_configs(
|
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
|
||||||
motors_bus_configs: dict[str, MotorsBusConfig],
|
|
||||||
) -> list[MotorsBus]:
|
|
||||||
motors_buses = {}
|
motors_buses = {}
|
||||||
|
|
||||||
for key, cfg in motors_bus_configs.items():
|
for key, cfg in motors_bus_configs.items():
|
||||||
|
|||||||
@@ -207,10 +207,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
|
|
||||||
print("Calibrate elbow_flex")
|
print("Calibrate elbow_flex")
|
||||||
calib["elbow_flex"] = move_to_calibrate(
|
calib["elbow_flex"] = move_to_calibrate(
|
||||||
arm,
|
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
|
||||||
"elbow_flex",
|
|
||||||
positive_first=False,
|
|
||||||
in_between_move_hook=in_between_move_hook,
|
|
||||||
)
|
)
|
||||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||||
|
|
||||||
@@ -242,11 +239,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
}
|
}
|
||||||
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
|
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
|
||||||
|
|
||||||
arm.write(
|
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
|
||||||
"Goal_Position",
|
|
||||||
round(calib["shoulder_lift"]["zero_pos"] - 1600),
|
|
||||||
"shoulder_lift",
|
|
||||||
)
|
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
|
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
@@ -257,11 +250,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
|
|
||||||
print("Calibrate wrist_roll")
|
print("Calibrate wrist_roll")
|
||||||
calib["wrist_roll"] = move_to_calibrate(
|
calib["wrist_roll"] = move_to_calibrate(
|
||||||
arm,
|
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
|
||||||
"wrist_roll",
|
|
||||||
invert_drive_mode=True,
|
|
||||||
positive_first=False,
|
|
||||||
while_move_hook=while_move_hook,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
|
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
|
||||||
|
|||||||
@@ -61,9 +61,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
|||||||
calib_dir.mkdir(parents=True, exist_ok=True)
|
calib_dir.mkdir(parents=True, exist_ok=True)
|
||||||
calib_file = calib_dir / "main_follower.json"
|
calib_file = calib_dir / "main_follower.json"
|
||||||
try:
|
try:
|
||||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||||
run_arm_manual_calibration,
|
|
||||||
)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("[WARNING] Calibration function not available. Skipping calibration.")
|
print("[WARNING] Calibration function not available. Skipping calibration.")
|
||||||
return
|
return
|
||||||
@@ -118,14 +116,7 @@ def run_lekiwi(robot_config):
|
|||||||
robot = LeKiwi(motors_bus)
|
robot = LeKiwi(motors_bus)
|
||||||
|
|
||||||
# Define the expected arm motor IDs.
|
# Define the expected arm motor IDs.
|
||||||
arm_motor_ids = [
|
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
|
||||||
"shoulder_pan",
|
|
||||||
"shoulder_lift",
|
|
||||||
"elbow_flex",
|
|
||||||
"wrist_flex",
|
|
||||||
"wrist_roll",
|
|
||||||
"gripper",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Disable torque for each arm motor.
|
# Disable torque for each arm motor.
|
||||||
for motor in arm_motor_ids:
|
for motor in arm_motor_ids:
|
||||||
@@ -139,9 +130,7 @@ def run_lekiwi(robot_config):
|
|||||||
images_lock = threading.Lock()
|
images_lock = threading.Lock()
|
||||||
stop_event = threading.Event()
|
stop_event = threading.Event()
|
||||||
cam_thread = threading.Thread(
|
cam_thread = threading.Thread(
|
||||||
target=run_camera_capture,
|
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
|
||||||
args=(cameras, images_lock, latest_images_dict, stop_event),
|
|
||||||
daemon=True,
|
|
||||||
)
|
)
|
||||||
cam_thread.start()
|
cam_thread.start()
|
||||||
|
|
||||||
|
|||||||
@@ -25,14 +25,9 @@ import zmq
|
|||||||
|
|
||||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||||
from lerobot.common.robot_devices.motors.utils import (
|
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||||
MotorsBus,
|
|
||||||
make_motors_buses_from_configs,
|
|
||||||
)
|
|
||||||
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
||||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||||
run_arm_manual_calibration,
|
|
||||||
)
|
|
||||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||||
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
|
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
|
||||||
|
|
||||||
@@ -329,11 +324,7 @@ class MobileManipulator:
|
|||||||
socks = dict(poller.poll(15))
|
socks = dict(poller.poll(15))
|
||||||
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
||||||
# No new data arrived → reuse ALL old data
|
# No new data arrived → reuse ALL old data
|
||||||
return (
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
self.last_frames,
|
|
||||||
self.last_present_speed,
|
|
||||||
self.last_remote_arm_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Drain all messages, keep only the last
|
# Drain all messages, keep only the last
|
||||||
last_msg = None
|
last_msg = None
|
||||||
@@ -346,11 +337,7 @@ class MobileManipulator:
|
|||||||
|
|
||||||
if not last_msg:
|
if not last_msg:
|
||||||
# No new message → also reuse old
|
# No new message → also reuse old
|
||||||
return (
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
self.last_frames,
|
|
||||||
self.last_present_speed,
|
|
||||||
self.last_remote_arm_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decode only the final message
|
# Decode only the final message
|
||||||
try:
|
try:
|
||||||
@@ -388,11 +375,7 @@ class MobileManipulator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[DEBUG] Error decoding video message: {e}")
|
print(f"[DEBUG] Error decoding video message: {e}")
|
||||||
# If decode fails, fall back to old data
|
# If decode fails, fall back to old data
|
||||||
return (
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
self.last_frames,
|
|
||||||
self.last_present_speed,
|
|
||||||
self.last_remote_arm_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
return frames, present_speed, remote_arm_state_tensor
|
return frames, present_speed, remote_arm_state_tensor
|
||||||
|
|
||||||
@@ -478,11 +461,7 @@ class MobileManipulator:
|
|||||||
|
|
||||||
body_state = self.wheel_raw_to_body(present_speed)
|
body_state = self.wheel_raw_to_body(present_speed)
|
||||||
|
|
||||||
body_state_mm = (
|
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||||
body_state[0] * 1000.0,
|
|
||||||
body_state[1] * 1000.0,
|
|
||||||
body_state[2],
|
|
||||||
) # Convert x,y to mm/s
|
|
||||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
||||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||||
|
|
||||||
@@ -641,11 +620,7 @@ class MobileManipulator:
|
|||||||
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
||||||
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
||||||
|
|
||||||
return {
|
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
||||||
"left_wheel": wheel_raw[0],
|
|
||||||
"back_wheel": wheel_raw[1],
|
|
||||||
"right_wheel": wheel_raw[2],
|
|
||||||
}
|
|
||||||
|
|
||||||
def wheel_raw_to_body(
|
def wheel_raw_to_body(
|
||||||
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||||
|
|||||||
@@ -72,9 +72,7 @@ def make_robot_from_config(config: RobotConfig):
|
|||||||
|
|
||||||
return ManipulatorRobot(config)
|
return ManipulatorRobot(config)
|
||||||
elif isinstance(config, LeKiwiRobotConfig):
|
elif isinstance(config, LeKiwiRobotConfig):
|
||||||
from lerobot.common.robot_devices.robots.mobile_manipulator import (
|
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
|
||||||
MobileManipulator,
|
|
||||||
)
|
|
||||||
|
|
||||||
return MobileManipulator(config)
|
return MobileManipulator(config)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -48,8 +48,7 @@ class RobotDeviceNotConnectedError(Exception):
|
|||||||
"""Exception raised when the robot device is not connected."""
|
"""Exception raised when the robot device is not connected."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
|
||||||
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
|
|
||||||
):
|
):
|
||||||
self.message = message
|
self.message = message
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|||||||
@@ -42,11 +42,7 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non
|
|||||||
"""
|
"""
|
||||||
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
|
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
|
||||||
"""
|
"""
|
||||||
py_state = (
|
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
|
||||||
rng_state_dict["py_rng_version"].item(),
|
|
||||||
tuple(rng_state_dict["py_rng_state"].tolist()),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
random.setstate(py_state)
|
random.setstate(py_state)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,10 +42,7 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis
|
|||||||
args = sys.argv[1:]
|
args = sys.argv[1:]
|
||||||
attr_level_args = []
|
attr_level_args = []
|
||||||
detect_string = f"--{field_name}."
|
detect_string = f"--{field_name}."
|
||||||
exclude_strings = (
|
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
|
||||||
f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
|
|
||||||
f"--{field_name}.{PATH_KEY}=",
|
|
||||||
)
|
|
||||||
for arg in args:
|
for arg in args:
|
||||||
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
|
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
|
||||||
denested_arg = f"--{arg.removeprefix(detect_string)}"
|
denested_arg = f"--{arg.removeprefix(detect_string)}"
|
||||||
|
|||||||
@@ -26,11 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
|||||||
from lerobot.common.optim.optimizers import OptimizerConfig
|
from lerobot.common.optim.optimizers import OptimizerConfig
|
||||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||||
auto_select_torch_device,
|
|
||||||
is_amp_available,
|
|
||||||
is_torch_device_available,
|
|
||||||
)
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
|
||||||
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
||||||
|
|||||||
@@ -38,12 +38,7 @@ def get_motor_bus_cls(brand: str) -> tuple:
|
|||||||
FeetechMotorsBus,
|
FeetechMotorsBus,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE
|
||||||
FeetechMotorsBusConfig,
|
|
||||||
FeetechMotorsBus,
|
|
||||||
MODEL_BAUDRATE_TABLE,
|
|
||||||
SCS_SERIES_BAUDRATE_TABLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif brand == "dynamixel":
|
elif brand == "dynamixel":
|
||||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||||
@@ -53,12 +48,7 @@ def get_motor_bus_cls(brand: str) -> tuple:
|
|||||||
DynamixelMotorsBus,
|
DynamixelMotorsBus,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE
|
||||||
DynamixelMotorsBusConfig,
|
|
||||||
DynamixelMotorsBus,
|
|
||||||
MODEL_BAUDRATE_TABLE,
|
|
||||||
X_SERIES_BAUDRATE_TABLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -174,25 +164,12 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
|
||||||
"--port",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Motors bus port (e.g. dynamixel,feetech)",
|
|
||||||
)
|
|
||||||
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
|
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
|
||||||
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
|
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
|
||||||
|
parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ID",
|
"--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
|
||||||
type=int,
|
|
||||||
required=True,
|
|
||||||
help="Desired ID of the current motor (e.g. 1,2,3)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--baudrate",
|
|
||||||
type=int,
|
|
||||||
default=1000000,
|
|
||||||
help="Desired baudrate for the motor (default: 1000000)",
|
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@@ -149,11 +149,7 @@ def init_sim_calibration(robot, cfg):
|
|||||||
axis_directions = np.array(cfg.get("axis_directions", [1]))
|
axis_directions = np.array(cfg.get("axis_directions", [1]))
|
||||||
offsets = np.array(cfg.get("offsets", [0])) * np.pi
|
offsets = np.array(cfg.get("offsets", [0])) * np.pi
|
||||||
|
|
||||||
return {
|
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
|
||||||
"start_pos": start_pos,
|
|
||||||
"axis_directions": axis_directions,
|
|
||||||
"offsets": offsets,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
|
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
|
||||||
@@ -203,7 +199,6 @@ def record(
|
|||||||
run_compute_stats: bool = True,
|
run_compute_stats: bool = True,
|
||||||
) -> LeRobotDataset:
|
) -> LeRobotDataset:
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
|
|
||||||
policy = None
|
policy = None
|
||||||
if pretrained_policy_name_or_path is not None:
|
if pretrained_policy_name_or_path is not None:
|
||||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||||
@@ -246,11 +241,7 @@ def record(
|
|||||||
shape = env.observation_space[key].shape
|
shape = env.observation_space[key].shape
|
||||||
if not key.startswith("observation.image."):
|
if not key.startswith("observation.image."):
|
||||||
key = "observation.image." + key
|
key = "observation.image." + key
|
||||||
features[key] = {
|
features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}
|
||||||
"dtype": "video",
|
|
||||||
"names": ["channels", "height", "width"],
|
|
||||||
"shape": shape,
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, obs_key in state_keys_dict.items():
|
for key, obs_key in state_keys_dict.items():
|
||||||
features[key] = {
|
features[key] = {
|
||||||
@@ -259,11 +250,7 @@ def record(
|
|||||||
"shape": env.observation_space[obs_key].shape,
|
"shape": env.observation_space[obs_key].shape,
|
||||||
}
|
}
|
||||||
|
|
||||||
features["action"] = {
|
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
|
||||||
"dtype": "float32",
|
|
||||||
"shape": env.action_space.shape,
|
|
||||||
"names": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create empty dataset or load existing saved episodes
|
# Create empty dataset or load existing saved episodes
|
||||||
sanity_check_dataset_name(repo_id, policy)
|
sanity_check_dataset_name(repo_id, policy)
|
||||||
@@ -374,12 +361,7 @@ def record(
|
|||||||
|
|
||||||
|
|
||||||
def replay(
|
def replay(
|
||||||
env,
|
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
|
||||||
root: Path,
|
|
||||||
repo_id: str,
|
|
||||||
episode: int,
|
|
||||||
fps: int | None = None,
|
|
||||||
local_files_only: bool = True,
|
|
||||||
):
|
):
|
||||||
env = env()
|
env = env()
|
||||||
|
|
||||||
@@ -426,10 +408,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--fps",
|
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||||
type=none_or_int,
|
|
||||||
default=None,
|
|
||||||
help="Frames per second (set to None to disable)",
|
|
||||||
)
|
)
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--root",
|
"--root",
|
||||||
@@ -507,19 +486,9 @@ if __name__ == "__main__":
|
|||||||
default=0,
|
default=0,
|
||||||
help="Resume recording on an existing dataset.",
|
help="Resume recording on an existing dataset.",
|
||||||
)
|
)
|
||||||
parser_record.add_argument(
|
|
||||||
"--assign-rewards",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||||
parser_replay.add_argument(
|
parser_replay.add_argument(
|
||||||
"--fps",
|
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||||
type=none_or_int,
|
|
||||||
default=None,
|
|
||||||
help="Frames per second (set to None to disable)",
|
|
||||||
)
|
)
|
||||||
parser_replay.add_argument(
|
parser_replay.add_argument(
|
||||||
"--root",
|
"--root",
|
||||||
|
|||||||
@@ -293,8 +293,7 @@ def eval_policy(
|
|||||||
seeds = None
|
seeds = None
|
||||||
else:
|
else:
|
||||||
seeds = range(
|
seeds = range(
|
||||||
start_seed + (batch_ix * env.num_envs),
|
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||||
start_seed + ((batch_ix + 1) * env.num_envs),
|
|
||||||
)
|
)
|
||||||
rollout_data = rollout(
|
rollout_data = rollout(
|
||||||
env,
|
env,
|
||||||
@@ -414,11 +413,7 @@ def eval_policy(
|
|||||||
|
|
||||||
|
|
||||||
def _compile_episode_data(
|
def _compile_episode_data(
|
||||||
rollout_data: dict,
|
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
|
||||||
done_indices: Tensor,
|
|
||||||
start_episode_index: int,
|
|
||||||
start_data_index: int,
|
|
||||||
fps: float,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Convenience function for `eval_policy(return_episode_data=True)`
|
"""Convenience function for `eval_policy(return_episode_data=True)`
|
||||||
|
|
||||||
@@ -486,10 +481,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
)
|
)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
with (
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
torch.no_grad(),
|
|
||||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
|
||||||
):
|
|
||||||
info = eval_policy(
|
info = eval_policy(
|
||||||
env,
|
env,
|
||||||
policy,
|
policy,
|
||||||
|
|||||||
@@ -196,11 +196,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
}
|
}
|
||||||
|
|
||||||
train_tracker = MetricsTracker(
|
train_tracker = MetricsTracker(
|
||||||
cfg.batch_size,
|
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
|
||||||
dataset.num_frames,
|
|
||||||
dataset.num_episodes,
|
|
||||||
train_metrics,
|
|
||||||
initial_step=step,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
@@ -271,11 +267,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||||
}
|
}
|
||||||
eval_tracker = MetricsTracker(
|
eval_tracker = MetricsTracker(
|
||||||
cfg.batch_size,
|
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||||
dataset.num_frames,
|
|
||||||
dataset.num_episodes,
|
|
||||||
eval_metrics,
|
|
||||||
initial_step=step,
|
|
||||||
)
|
)
|
||||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
||||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
||||||
|
|||||||
@@ -81,11 +81,7 @@ def run_server(
|
|||||||
static_folder: Path,
|
static_folder: Path,
|
||||||
template_folder: Path,
|
template_folder: Path,
|
||||||
):
|
):
|
||||||
app = Flask(
|
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||||
__name__,
|
|
||||||
static_folder=static_folder.resolve(),
|
|
||||||
template_folder=template_folder.resolve(),
|
|
||||||
)
|
|
||||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
@@ -201,8 +197,7 @@ def run_server(
|
|||||||
]
|
]
|
||||||
|
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
|
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
|
||||||
timeout=5,
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
# Split into lines and parse each line as JSON
|
# Split into lines and parse each line as JSON
|
||||||
@@ -287,8 +282,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
|||||||
repo_id = dataset.repo_id
|
repo_id = dataset.repo_id
|
||||||
|
|
||||||
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
|
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
|
||||||
episode_chunk=int(episode_index) // dataset.chunks_size,
|
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
|
||||||
episode_index=episode_index,
|
|
||||||
)
|
)
|
||||||
df = pd.read_parquet(url)
|
df = pd.read_parquet(url)
|
||||||
data = df[selected_columns] # Select specific columns
|
data = df[selected_columns] # Select specific columns
|
||||||
@@ -337,8 +331,7 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
|
|||||||
|
|
||||||
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
|
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
|
||||||
timeout=5,
|
|
||||||
)
|
)
|
||||||
response.raise_for_status() # Raises an HTTPError for bad responses
|
response.raise_for_status() # Raises an HTTPError for bad responses
|
||||||
dataset_info = response.json()
|
dataset_info = response.json()
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ dependencies = [
|
|||||||
"rerun-sdk>=0.21.0",
|
"rerun-sdk>=0.21.0",
|
||||||
"termcolor>=2.4.0",
|
"termcolor>=2.4.0",
|
||||||
"torch>=2.2.1,<2.7",
|
"torch>=2.2.1,<2.7",
|
||||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
"torchcodec==0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||||
"torchmetrics>=1.6.0",
|
"torchmetrics>=1.6.0",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
"transformers>=4.47.0",
|
"transformers>=4.47.0",
|
||||||
|
|||||||
@@ -59,33 +59,16 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
|||||||
"action": {
|
"action": {
|
||||||
"dtype": "float32",
|
"dtype": "float32",
|
||||||
"shape": (6,),
|
"shape": (6,),
|
||||||
"names": [
|
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||||
"shoulder_pan",
|
|
||||||
"shoulder_lift",
|
|
||||||
"elbow_flex",
|
|
||||||
"wrist_flex",
|
|
||||||
"wrist_roll",
|
|
||||||
"gripper",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
"observation.state": {
|
"observation.state": {
|
||||||
"dtype": "float32",
|
"dtype": "float32",
|
||||||
"shape": (6,),
|
"shape": (6,),
|
||||||
"names": [
|
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||||
"shoulder_pan",
|
|
||||||
"shoulder_lift",
|
|
||||||
"elbow_flex",
|
|
||||||
"wrist_flex",
|
|
||||||
"wrist_roll",
|
|
||||||
"gripper",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
info = info_factory(
|
info = info_factory(
|
||||||
total_episodes=1,
|
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
|
||||||
total_frames=1,
|
|
||||||
camera_features=camera_features,
|
|
||||||
motor_features=motor_features,
|
|
||||||
)
|
)
|
||||||
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
||||||
return ds_meta
|
return ds_meta
|
||||||
@@ -98,8 +81,7 @@ def test_get_policy_and_config_classes(policy_name: str):
|
|||||||
policy_cfg = make_policy_config(policy_name)
|
policy_cfg = make_policy_config(policy_name)
|
||||||
assert policy_cls.name == policy_name
|
assert policy_cls.name == policy_name
|
||||||
assert issubclass(
|
assert issubclass(
|
||||||
policy_cfg.__class__,
|
policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
|
||||||
inspect.signature(policy_cls.__init__).parameters["config"].annotation,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -110,13 +92,7 @@ def test_get_policy_and_config_classes(policy_name: str):
|
|||||||
("lerobot/pusht", "pusht", {}, "diffusion", {}),
|
("lerobot/pusht", "pusht", {}, "diffusion", {}),
|
||||||
("lerobot/pusht", "pusht", {}, "vqbet", {}),
|
("lerobot/pusht", "pusht", {}, "vqbet", {}),
|
||||||
("lerobot/pusht", "pusht", {}, "act", {}),
|
("lerobot/pusht", "pusht", {}, "act", {}),
|
||||||
(
|
("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
|
||||||
"lerobot/aloha_sim_insertion_human",
|
|
||||||
"aloha",
|
|
||||||
{"task": "AlohaInsertion-v0"},
|
|
||||||
"act",
|
|
||||||
{},
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
"lerobot/aloha_sim_insertion_scripted",
|
"lerobot/aloha_sim_insertion_scripted",
|
||||||
"aloha",
|
"aloha",
|
||||||
|
|||||||
@@ -390,8 +390,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"robot_type, mock, num_image_writer_processes",
|
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
|
||||||
[("koch", True, 0), ("koch", True, 1)],
|
|
||||||
)
|
)
|
||||||
@require_robot
|
@require_robot
|
||||||
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
|
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
|
||||||
|
|||||||
@@ -40,10 +40,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||||
from lerobot.common.robot_devices.utils import (
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
RobotDeviceAlreadyConnectedError,
|
|
||||||
RobotDeviceNotConnectedError,
|
|
||||||
)
|
|
||||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,9 +26,7 @@ from lerobot import available_cameras, available_motors, available_robots
|
|||||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||||
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
||||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||||
from lerobot.common.robot_devices.motors.utils import (
|
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
|
||||||
make_motors_bus as make_motors_bus_device,
|
|
||||||
)
|
|
||||||
from lerobot.common.utils.import_utils import is_package_available
|
from lerobot.common.utils.import_utils import is_package_available
|
||||||
|
|
||||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||||
@@ -287,14 +285,7 @@ def mock_calibration_dir(calibration_dir):
|
|||||||
"start_pos": [1442, 843, 2166, 2849, 1988, 1835],
|
"start_pos": [1442, 843, 2166, 2849, 1988, 1835],
|
||||||
"end_pos": [2440, 1869, -1106, -1848, -926, 3235],
|
"end_pos": [2440, 1869, -1106, -1848, -926, 3235],
|
||||||
"calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"],
|
"calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"],
|
||||||
"motor_names": [
|
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||||
"shoulder_pan",
|
|
||||||
"shoulder_lift",
|
|
||||||
"elbow_flex",
|
|
||||||
"wrist_flex",
|
|
||||||
"wrist_roll",
|
|
||||||
"gripper",
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
|
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
|
||||||
with open(calibration_dir / "main_follower.json", "w") as f:
|
with open(calibration_dir / "main_follower.json", "w") as f:
|
||||||
|
|||||||
@@ -18,10 +18,7 @@ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_metrics():
|
def mock_metrics():
|
||||||
return {
|
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
||||||
"loss": AverageMeter("loss", ":.3f"),
|
|
||||||
"accuracy": AverageMeter("accuracy", ":.2f"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_average_meter_initialization():
|
def test_average_meter_initialization():
|
||||||
@@ -61,11 +58,7 @@ def test_average_meter_str():
|
|||||||
|
|
||||||
def test_metrics_tracker_initialization(mock_metrics):
|
def test_metrics_tracker_initialization(mock_metrics):
|
||||||
tracker = MetricsTracker(
|
tracker = MetricsTracker(
|
||||||
batch_size=32,
|
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
|
||||||
num_frames=1000,
|
|
||||||
num_episodes=50,
|
|
||||||
metrics=mock_metrics,
|
|
||||||
initial_step=10,
|
|
||||||
)
|
)
|
||||||
assert tracker.steps == 10
|
assert tracker.steps == 10
|
||||||
assert tracker.samples == 10 * 32
|
assert tracker.samples == 10 * 32
|
||||||
@@ -77,11 +70,7 @@ def test_metrics_tracker_initialization(mock_metrics):
|
|||||||
|
|
||||||
def test_metrics_tracker_step(mock_metrics):
|
def test_metrics_tracker_step(mock_metrics):
|
||||||
tracker = MetricsTracker(
|
tracker = MetricsTracker(
|
||||||
batch_size=32,
|
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
|
||||||
num_frames=1000,
|
|
||||||
num_episodes=50,
|
|
||||||
metrics=mock_metrics,
|
|
||||||
initial_step=5,
|
|
||||||
)
|
)
|
||||||
tracker.step()
|
tracker.step()
|
||||||
assert tracker.steps == 6
|
assert tracker.steps == 6
|
||||||
|
|||||||
Reference in New Issue
Block a user