[HIL-SERL]Remove overstrict pre-commit modifications (#1028)

This commit is contained in:
Adil Zouitine
2025-04-24 13:48:52 +02:00
committed by GitHub
parent 671ac3411f
commit c58b504a9e
47 changed files with 163 additions and 757 deletions

View File

@@ -32,11 +32,7 @@ import numpy as np
import pandas as pd import pandas as pd
import PIL import PIL
import torch import torch
from skimage.metrics import ( from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
mean_squared_error,
peak_signal_noise_ratio,
structural_similarity,
)
from tqdm import tqdm from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -98,11 +94,7 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames( def save_decoded_frames(
imgs_dir: Path, imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
save_dir: Path,
frames: torch.Tensor,
timestamps: list[float],
fps: int,
) -> None: ) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps): if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return return
@@ -112,10 +104,7 @@ def save_decoded_frames(
idx = int(ts * fps) idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy() frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png") PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile( shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
imgs_dir / f"frame_{idx:06d}.png",
save_dir / f"frame_{idx:06d}_original.png",
)
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
@@ -131,11 +120,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
imgs_dataset = hf_dataset.select_columns(img_keys[0]) imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate( for i, item in enumerate(
tqdm( tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
imgs_dataset,
desc=f"saving {dataset.repo_id} first episode images",
leave=False,
)
): ):
img = item[img_keys[0]] img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100) img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
@@ -290,9 +275,7 @@ def benchmark_encoding_decoding(
random.seed(seed) random.seed(seed)
benchmark_table = [] benchmark_table = []
for timestamps_mode in tqdm( for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
desc="decodings (timestamps_modes)",
leave=False,
): ):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False): for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
benchmark_row = benchmark_decoding( benchmark_row = benchmark_decoding(

View File

@@ -32,10 +32,7 @@ import torch
from huggingface_hub import HfApi from huggingface_hub import HfApi
import lerobot import lerobot
from lerobot.common.datasets.lerobot_dataset import ( from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
LeRobotDataset,
LeRobotDatasetMetadata,
)
# We ported a number of existing datasets ourselves, use this to see the list: # We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:") print("List of available datasets:")

View File

@@ -22,10 +22,7 @@ from pathlib import Path
import torch import torch
from lerobot.common.datasets.lerobot_dataset import ( from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
@@ -80,24 +77,7 @@ def main():
# Load the previous action (-0.1), the next action to be executed (0.0), # Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be # and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to supervise the policy. # used to supervise the policy.
"action": [ "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
} }
# We can then instantiate the dataset with these delta_timestamps configuration. # We can then instantiate the dataset with these delta_timestamps configuration.

View File

@@ -26,10 +26,7 @@ import math
import torch import torch
from lerobot.common.datasets.lerobot_dataset import ( from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
@@ -54,24 +51,7 @@ def main():
# Load the previous action (-0.1), the next action to be executed (0.0), # Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be # and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to calculate the loss. # used to calculate the loss.
"action": [ "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
} }
# Load the last 10% of episodes of the dataset as a validation set. # Load the last 10% of episodes of the dataset as a validation set.

View File

@@ -19,10 +19,7 @@ from lerobot.common.datasets.utils import load_image_as_numpy
def estimate_num_samples( def estimate_num_samples(
dataset_len: int, dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
min_num_samples: int = 100,
max_num_samples: int = 10_000,
power: float = 0.75,
) -> int: ) -> int:
"""Heuristic to estimate the number of samples based on dataset size. """Heuristic to estimate the number of samples based on dataset size.
The power controls the sample growth relative to dataset size. The power controls the sample growth relative to dataset size.
@@ -126,9 +123,7 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
def aggregate_feature_stats( def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
stats_ft_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature.""" """Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list]) means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list]) variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
@@ -157,9 +152,7 @@ def aggregate_feature_stats(
} }
def aggregate_stats( def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
stats_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats. """Aggregate stats from multiple compute_stats outputs into a single set of stats.
The final stats will have the union of all data keys from each of the stats dicts. The final stats will have the union of all data keys from each of the stats dicts.

View File

@@ -154,32 +154,14 @@ class OnlineBuffer(torch.utils.data.Dataset):
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()}, OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization. # with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: { OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
"dtype": np.dtype("?"), OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
"shape": (buffer_capacity,), OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
}, OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: { OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.FRAME_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.EPISODE_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.TIMESTAMP_KEY: {
"dtype": np.dtype("float64"),
"shape": (buffer_capacity,),
},
} }
for k, v in data_spec.items(): for k, v in data_spec.items():
complete_data_spec[k] = { complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
"dtype": v["dtype"],
"shape": (buffer_capacity, *v["shape"]),
}
return complete_data_spec return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]): def add_data(self, data: dict[str, np.ndarray]):

View File

@@ -77,9 +77,7 @@ def check_repo_id(repo_id: str) -> None:
# TODO(aliberts): remove # TODO(aliberts): remove
def calculate_episode_data_index( def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
hf_dataset: datasets.Dataset,
) -> Dict[str, torch.Tensor]:
""" """
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.

View File

@@ -43,10 +43,7 @@ class EpisodeAwareSampler:
): ):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use: if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend( indices.extend(
range( range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
start_index.item() + drop_n_first_frames,
end_index.item() - drop_n_last_frames,
)
) )
self.indices = indices self.indices = indices

View File

@@ -118,10 +118,7 @@ DATASETS = {
"single_task": "Place the battery into the slot of the remote controller.", "single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
}, },
"aloha_static_candy": { "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
"single_task": "Pick up the candy and unwrap it.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee": { "aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.", "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
@@ -170,22 +167,13 @@ DATASETS = {
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.", "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
}, },
"aloha_static_ziploc_slide": { "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
"single_task": "Slide open the ziploc bag.", "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted_image": { "aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.", "single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
}, },
"aloha_sim_insertion_human": { "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human_image": { "aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.", "single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
@@ -206,19 +194,10 @@ DATASETS = {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.", "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
}, },
"pusht": { "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"single_task": "Push the T-shaped block onto the T-shaped target.", "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
**PUSHT_INFO,
},
"pusht_image": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO}, "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": { "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
"single_task": "Put the object into the bin.",
**UNITREEH_INFO,
},
"unitreeh1_two_robot_greeting": { "unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.", "single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO, **UNITREEH_INFO,
@@ -228,31 +207,13 @@ DATASETS = {
**UNITREEH_INFO, **UNITREEH_INFO,
}, },
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": { "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"single_task": "Pick up the cube and lift it.", "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
**XARM_INFO, "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
},
"xarm_lift_medium_replay": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO}, "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": { "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"single_task": "Push the cube onto the target.", "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
**XARM_INFO, "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
},
"xarm_push_medium_replay": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"umi_cup_in_the_wild": { "umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.", "single_task": "Put the cup on the plate.",
"license": "apache-2.0", "license": "apache-2.0",

View File

@@ -379,12 +379,7 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
for i in range(0, len(lfs_untracked_videos), 100): for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100] files = lfs_untracked_videos[i : i + 100]
try: try:
subprocess.run( subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
["git", "rm", "--cached", *files],
cwd=work_dir,
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:") print("git rm --cached ERROR:")
print(e.stderr) print(e.stderr)
@@ -407,17 +402,7 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
repo_url = f"https://huggingface.co/datasets/{repo_id}" repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run( subprocess.run(
[ ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
"git",
"clone",
"--branch",
branch,
"--single-branch",
"--depth",
"1",
repo_url,
str(work_dir),
],
check=True, check=True,
env=env, env=env,
) )
@@ -425,11 +410,7 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]: def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run( lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
cwd=work_dir,
capture_output=True,
text=True,
check=True,
) )
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines()) lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files] return [f for f in video_files if f not in lfs_tracked_files]
@@ -443,11 +424,7 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
] ]
hub_api = HfApi() hub_api = HfApi()
hub_api.snapshot_download( hub_api.snapshot_download(
repo_id=repo_id, repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
repo_type="dataset",
local_dir=local_dir,
revision=branch,
allow_patterns=video_files,
) )
videos_info_dict = {} videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True): for vid_key, vid_path in zip(video_keys, video_files, strict=True):
@@ -474,11 +451,7 @@ def convert_dataset(
hub_api = HfApi() hub_api = HfApi()
hub_api.snapshot_download( hub_api.snapshot_download(
repo_id=repo_id, repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
repo_type="dataset",
revision=v1,
local_dir=v1x_dir,
ignore_patterns="videos*/",
) )
branch = "main" branch = "main"
if test_branch: if test_branch:
@@ -536,21 +509,12 @@ def convert_dataset(
dataset = dataset.remove_columns(video_keys) dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path( clean_gitattr = Path(
hub_api.hf_hub_download( hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
repo_type="dataset",
local_dir=local_dir,
filename=".gitattributes",
) )
).absolute() ).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir: with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos( move_videos(
repo_id, repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
video_keys,
total_episodes,
total_chunks,
Path(tmp_video_dir),
clean_gitattr,
branch,
) )
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch) videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for key in video_keys: for key in video_keys:
@@ -579,11 +543,7 @@ def convert_dataset(
# Episodes # Episodes
episodes = [ episodes = [
{ {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
"episode_index": ep_idx,
"tasks": tasks_by_episodes[ep_idx],
"length": episode_lengths[ep_idx],
}
for ep_idx in episode_indices for ep_idx in episode_indices
] ]
write_jsonlines(episodes, v20_dir / EPISODES_PATH) write_jsonlines(episodes, v20_dir / EPISODES_PATH)
@@ -612,12 +572,7 @@ def convert_dataset(
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch) hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder( hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
repo_id=repo_id,
path_in_repo="meta_data",
repo_type="dataset",
revision=branch,
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch) hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)

View File

@@ -37,16 +37,8 @@ import logging
from huggingface_hub import HfApi from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
EPISODES_STATS_PATH, from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
STATS_PATH,
load_stats,
write_info,
)
from lerobot.common.datasets.v21.convert_stats import (
check_aggregate_stats,
convert_stats,
)
V20 = "v2.0" V20 = "v2.0"
V21 = "v2.1" V21 = "v2.1"
@@ -87,16 +79,10 @@ def convert_dataset(
hub_api = HfApi() hub_api = HfApi()
if hub_api.file_exists( if hub_api.file_exists(
repo_id=dataset.repo_id, repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
filename=STATS_PATH,
revision=branch,
repo_type="dataset",
): ):
hub_api.delete_file( hub_api.delete_file(
path_in_repo=STATS_PATH, path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
repo_id=dataset.repo_id,
revision=branch,
repo_type="dataset",
) )
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

View File

@@ -17,11 +17,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from lerobot.common.datasets.compute_stats import ( from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
aggregate_stats,
get_feature_stats,
sample_indices,
)
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats from lerobot.common.datasets.utils import write_episode_stats
@@ -99,9 +95,5 @@ def check_aggregate_stats(
if key in reference_stats and stat in reference_stats[key]: if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'" err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose( np.testing.assert_allclose(
val, val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
reference_stats[key][stat],
rtol=rtol,
atol=atol,
err_msg=err_msg,
) )

View File

@@ -49,11 +49,7 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
kwargs = { kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
**asdict(self),
"num_training_steps": num_training_steps,
"optimizer": optimizer,
}
return get_scheduler(**kwargs) return get_scheduler(**kwargs)
@@ -75,10 +71,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
progress = float(adjusted_step - self.num_warmup_steps) / float( progress = float(adjusted_step - self.num_warmup_steps) / float(
max(1, num_training_steps - self.num_warmup_steps) max(1, num_training_steps - self.num_warmup_steps)
) )
return max( return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
0.0,
0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)),
)
return LambdaLR(optimizer, lr_lambda, -1) return LambdaLR(optimizer, lr_lambda, -1)

View File

@@ -241,9 +241,7 @@ class ACTTemporalEnsembler:
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later. # operations later.
self.ensembled_actions_count = torch.ones( self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
dtype=torch.long,
device=self.ensembled_actions.device,
) )
else: else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
@@ -255,10 +253,7 @@ class ACTTemporalEnsembler:
# The last action, which has no prior online average, needs to get concatenated onto the end. # The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat( self.ensembled_actions_count = torch.cat(
[ [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
self.ensembled_actions_count,
torch.ones_like(self.ensembled_actions_count[-1:]),
]
) )
# "Consume" the first action. # "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = ( action, self.ensembled_actions, self.ensembled_actions_count = (
@@ -338,11 +333,7 @@ class ACT(nn.Module):
# Backbone for image feature extraction. # Backbone for image feature extraction.
if self.config.image_features: if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)( backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[ replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
False,
False,
config.replace_final_stride_with_dilation,
],
weights=config.pretrained_backbone_weights, weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d, norm_layer=FrozenBatchNorm2d,
) )
@@ -436,11 +427,7 @@ class ACT(nn.Module):
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.config.robot_state_feature: if self.config.robot_state_feature:
vae_encoder_input = [ vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
cls_embed,
robot_state_embed,
action_embed,
] # (B, S+2, D)
else: else:
vae_encoder_input = [cls_embed, action_embed] vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1) vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
@@ -553,10 +540,7 @@ class ACTEncoder(nn.Module):
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward( def forward(
self, self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
) -> Tensor: ) -> Tensor:
for layer in self.layers: for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@@ -619,10 +603,7 @@ class ACTDecoder(nn.Module):
) -> Tensor: ) -> Tensor:
for layer in self.layers: for layer in self.layers:
x = layer( x = layer(
x, x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
) )
if self.norm is not None: if self.norm is not None:
x = self.norm(x) x = self.norm(x)

View File

@@ -209,10 +209,7 @@ class DiffusionModel(nn.Module):
# ========= inference ============ # ========= inference ============
def conditional_sample( def conditional_sample(
self, self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
) -> Tensor: ) -> Tensor:
device = get_device_from_parameters(self) device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self) dtype = get_dtype_from_parameters(self)
@@ -257,10 +254,7 @@ class DiffusionModel(nn.Module):
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features). # feature dim (effectively concatenating the camera features).
img_features = einops.rearrange( img_features = einops.rearrange(
img_features_list, img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
"(n b s) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
) )
else: else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder. # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
@@ -270,10 +264,7 @@ class DiffusionModel(nn.Module):
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features). # feature dim (effectively concatenating the camera features).
img_features = einops.rearrange( img_features = einops.rearrange(
img_features, img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
"(b s n) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
) )
global_cond_feats.append(img_features) global_cond_feats.append(img_features)
@@ -524,9 +515,7 @@ class DiffusionRgbEncoder(nn.Module):
def _replace_submodules( def _replace_submodules(
root_module: nn.Module, root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module: ) -> nn.Module:
""" """
Args: Args:
@@ -644,14 +633,10 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList( self.mid_modules = nn.ModuleList(
[ [
DiffusionConditionalResidualBlock1d( DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
**common_res_block_kwargs,
), ),
DiffusionConditionalResidualBlock1d( DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
**common_res_block_kwargs,
), ),
] ]
) )

View File

@@ -61,11 +61,7 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
) )
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
PRECISIONS = { PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
"bfloat16": torch.bfloat16,
"float32": torch.float32,
"float16": torch.float16,
}
def slice_paligemma_state_dict(state_dict, config): def slice_paligemma_state_dict(state_dict, config):

View File

@@ -48,32 +48,18 @@ def flex_attention_forward(
key_states = key_states[:, :, :, None, :] key_states = key_states[:, :, :, None, :]
key_states = key_states.expand( key_states = key_states.expand(
batch_size, batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
key_states.shape[1],
num_key_value_heads,
num_key_value_groups,
head_dim,
) )
key_states = key_states.reshape( key_states = key_states.reshape(
batch_size, batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
key_states.shape[1],
num_key_value_heads * num_key_value_groups,
head_dim,
) )
value_states = value_states[:, :, :, None, :] value_states = value_states[:, :, :, None, :]
value_states = value_states.expand( value_states = value_states.expand(
batch_size, batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
value_states.shape[1],
num_key_value_heads,
num_key_value_groups,
head_dim,
) )
value_states = value_states.reshape( value_states = value_states.reshape(
batch_size, batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
value_states.shape[1],
num_key_value_heads * num_key_value_groups,
head_dim,
) )
query_states = query_states.transpose(1, 2) query_states = query_states.transpose(1, 2)

View File

@@ -69,11 +69,7 @@ from lerobot.common.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding( def create_sinusoidal_pos_embedding(
time: torch.tensor, time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
dimension: int,
min_period: float,
max_period: float,
device="cpu",
) -> Tensor: ) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions.""" """Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0: if dimension % 2 != 0:
@@ -581,11 +577,7 @@ class PI0FlowMatching(nn.Module):
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding( time_emb = create_sinusoidal_pos_embedding(
timestep, timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
self.config.proj_width,
min_period=4e-3,
max_period=4.0,
device=device,
) )
time_emb = time_emb.type(dtype=dtype) time_emb = time_emb.type(dtype=dtype)
@@ -617,15 +609,7 @@ class PI0FlowMatching(nn.Module):
return embs, pad_masks, att_masks return embs, pad_masks, att_masks
def forward( def forward(
self, self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
images,
img_masks,
lang_tokens,
lang_masks,
state,
actions,
noise=None,
time=None,
) -> Tensor: ) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None: if noise is None:
@@ -671,11 +655,7 @@ class PI0FlowMatching(nn.Module):
device = state.device device = state.device
if noise is None: if noise is None:
actions_shape = ( actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
bsize,
self.config.n_action_steps,
self.config.max_action_dim,
)
noise = self.sample_noise(actions_shape, device) noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(

View File

@@ -293,18 +293,12 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
# in `transformers`. (molbap) # in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat( value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], [past_key_values[layer_idx]["value_states"], value_states], dim=1
dim=1,
) )
attention_interface = self.get_attention_interface() attention_interface = self.get_attention_interface()
att_output = attention_interface( att_output = attention_interface(
attention_mask, attention_mask, batch_size, head_dim, query_states, key_states, value_states
batch_size,
head_dim,
query_states,
key_states,
value_states,
) )
att_output = att_output.to(dtype=torch.bfloat16) att_output = att_output.to(dtype=torch.bfloat16)
@@ -364,24 +358,12 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
return attention_interface return attention_interface
def flash_attention_forward( def flash_attention_forward(
self, self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
): ):
raise NotImplementedError("FA2 is not implemented (yet)") raise NotImplementedError("FA2 is not implemented (yet)")
def eager_attention_forward( def eager_attention_forward(
self, self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
): ):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
@@ -393,31 +375,17 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
sequence_length = key_states.shape[1] sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand( key_states = key_states[:, :, :, None, :].expand(
batch_size, batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
sequence_length,
num_key_value_heads,
num_key_value_groups,
head_dim,
) )
key_states = key_states.reshape( key_states = key_states.reshape(
batch_size, batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
sequence_length,
num_key_value_heads * num_key_value_groups,
head_dim,
) )
value_states = value_states[:, :, :, None, :].expand( value_states = value_states[:, :, :, None, :].expand(
batch_size, batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
sequence_length,
num_key_value_heads,
num_key_value_groups,
head_dim,
) )
value_states = value_states.reshape( value_states = value_states.reshape(
batch_size, batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
sequence_length,
num_key_value_heads * num_key_value_groups,
head_dim,
) )
# Attention here is upcasted to float32 to match the original eager implementation. # Attention here is upcasted to float32 to match the original eager implementation.

View File

@@ -39,11 +39,7 @@ from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import ( from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
get_device_from_parameters,
get_output_shape,
populate_queues,
)
class TDMPCPolicy(PreTrainedPolicy): class TDMPCPolicy(PreTrainedPolicy):
@@ -67,11 +63,7 @@ class TDMPCPolicy(PreTrainedPolicy):
config_class = TDMPCConfig config_class = TDMPCConfig
name = "tdmpc" name = "tdmpc"
def __init__( def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
self,
config: TDMPCConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
""" """
Args: Args:
config: Policy configuration class instance or None, in which case the default instantiation of config: Policy configuration class instance or None, in which case the default instantiation of
@@ -197,20 +189,13 @@ class TDMPCPolicy(PreTrainedPolicy):
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories. # trajectories.
z = einops.repeat( z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
z,
"b d -> n b d",
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm. # algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM). # The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros( mean = torch.zeros(
self.config.horizon, self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
batch_size,
self.config.action_feature.shape[0],
device=device,
) )
# Maybe warm start CEM with the mean from the previous step. # Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None: if self._prev_mean is not None:
@@ -306,10 +291,9 @@ class TDMPCPolicy(PreTrainedPolicy):
if self.config.q_ensemble_size > 2: if self.config.q_ensemble_size > 2:
G += ( G += (
running_discount running_discount
* torch.min( * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], 0
dim=0, ]
)[0]
) )
else: else:
G += running_discount * torch.min(terminal_values, dim=0)[0] G += running_discount * torch.min(terminal_values, dim=0)[0]
@@ -345,10 +329,7 @@ class TDMPCPolicy(PreTrainedPolicy):
# Apply random image augmentations. # Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0: if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten( observations["observation.image"] = flatten_forward_unflatten(
partial( partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
random_shifts_aug,
max_random_shift_ratio=self.config.max_random_shift_ratio,
),
observations["observation.image"], observations["observation.image"],
) )
@@ -572,10 +553,7 @@ class TDMPCTOLD(nn.Module):
self._Qs = nn.ModuleList( self._Qs = nn.ModuleList(
[ [
nn.Sequential( nn.Sequential(
nn.Linear( nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
config.latent_dim + config.action_feature.shape[0],
config.mlp_dim,
),
nn.LayerNorm(config.mlp_dim), nn.LayerNorm(config.mlp_dim),
nn.Tanh(), nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim), nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -724,26 +702,11 @@ class TDMPCObservationEncoder(nn.Module):
stride=2, stride=2,
), ),
nn.ReLU(), nn.ReLU(),
nn.Conv2d( nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
5,
stride=2,
),
nn.ReLU(), nn.ReLU(),
nn.Conv2d( nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(), nn.ReLU(),
nn.Conv2d( nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(), nn.ReLU(),
) )
dummy_shape = (1, *next(iter(config.image_features.values())).shape) dummy_shape = (1, *next(iter(config.image_features.values())).shape)
@@ -786,8 +749,7 @@ class TDMPCObservationEncoder(nn.Module):
if self.config.image_features: if self.config.image_features:
feat.append( feat.append(
flatten_forward_unflatten( flatten_forward_unflatten(
self.image_enc_layers, self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
obs_dict[next(iter(self.config.image_features))],
) )
) )
if self.config.env_state_feature: if self.config.env_state_feature:
@@ -834,9 +796,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param.""" """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True): for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip( for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
module.named_parameters(recurse=False),
strict=True,
): ):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update" assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict): if isinstance(p, dict):

View File

@@ -193,12 +193,7 @@ class VQBeTConfig(PreTrainedConfig):
@property @property
def action_delta_indices(self) -> list: def action_delta_indices(self) -> list:
return list( return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
range(
1 - self.n_obs_steps,
self.n_action_pred_token + self.action_chunk_size - 1,
)
)
@property @property
def reward_delta_indices(self) -> None: def reward_delta_indices(self) -> None:

View File

@@ -29,11 +29,7 @@ from torch import Tensor, nn
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import ( from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
get_device_from_parameters,
get_output_shape,
populate_queues,
)
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
@@ -328,8 +324,7 @@ class VQBeTModel(nn.Module):
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT. # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP( self.state_projector = MLP(
config.robot_state_feature.shape[0], config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
hidden_channels=[self.config.gpt_input_dim],
) )
self.rgb_feature_projector = MLP( self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim] self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@@ -359,11 +354,7 @@ class VQBeTModel(nn.Module):
) )
# Separate batch and sequence dims. # Separate batch and sequence dims.
img_features = einops.rearrange( img_features = einops.rearrange(
img_features, img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
"(b s n) ... -> b s n ...",
b=batch_size,
s=n_obs_steps,
n=self.num_images,
) )
# Arrange prior and current observation step tokens as shown in the class docstring. # Arrange prior and current observation step tokens as shown in the class docstring.
@@ -400,11 +391,7 @@ class VQBeTModel(nn.Module):
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional). # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
if len_additional_action_token > 0: if len_additional_action_token > 0:
features = torch.cat( features = torch.cat(
[ [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:],
],
dim=1,
) )
else: else:
features = features[:, historical_act_pred_index] features = features[:, historical_act_pred_index]
@@ -527,13 +514,7 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin( cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
torch.cat( torch.cat(
( (x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
x,
F.one_hot(
sampled_primary_centers,
num_classes=self.config.vqvae_n_embed,
),
),
axis=1, axis=1,
) )
) )
@@ -551,9 +532,7 @@ class VQBeTHead(nn.Module):
else: else:
cbet_logits = self.map_to_cbet_preds_bin(x) cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_logits = einops.rearrange( cbet_logits = einops.rearrange(
cbet_logits, cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
"(NT) (G C) -> (NT) G C",
G=self.vqvae_model.vqvae_num_layers,
) )
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1) cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape NT, G, choices = cbet_probs.shape
@@ -751,9 +730,7 @@ class VQBeTRgbEncoder(nn.Module):
def _replace_submodules( def _replace_submodules(
root_module: nn.Module, root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module: ) -> nn.Module:
""" """
Args: Args:

View File

@@ -377,10 +377,7 @@ class ResidualVQ(nn.Module):
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
VectorQuantize( VectorQuantize(
dim=codebook_dim, dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
codebook_dim=codebook_dim,
accept_image_fmap=accept_image_fmap,
**kwargs,
) )
for _ in range(num_quantizers) for _ in range(num_quantizers)
] ]

View File

@@ -297,11 +297,7 @@ class IntelRealSenseCamera:
if self.fps and self.capture_width and self.capture_height: if self.fps and self.capture_width and self.capture_height:
# TODO(rcadene): can we set rgb8 directly? # TODO(rcadene): can we set rgb8 directly?
config.enable_stream( config.enable_stream(
rs.stream.color, rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
self.capture_width,
self.capture_height,
rs.format.rgb8,
self.fps,
) )
else: else:
config.enable_stream(rs.stream.color) config.enable_stream(rs.stream.color)
@@ -309,11 +305,7 @@ class IntelRealSenseCamera:
if self.use_depth: if self.use_depth:
if self.fps and self.capture_width and self.capture_height: if self.fps and self.capture_width and self.capture_height:
config.enable_stream( config.enable_stream(
rs.stream.depth, rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
self.capture_width,
self.capture_height,
rs.format.z16,
self.fps,
) )
else: else:
config.enable_stream(rs.stream.depth) config.enable_stream(rs.stream.depth)

View File

@@ -41,9 +41,7 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
cameras[key] = OpenCVCamera(cfg) cameras[key] = OpenCVCamera(cfg)
elif cfg.type == "intelrealsense": elif cfg.type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import ( from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
IntelRealSenseCamera,
)
cameras[key] = IntelRealSenseCamera(cfg) cameras[key] = IntelRealSenseCamera(cfg)
else: else:
@@ -60,9 +58,7 @@ def make_camera(camera_type, **kwargs) -> Camera:
return OpenCVCamera(config) return OpenCVCamera(config)
elif camera_type == "intelrealsense": elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import ( from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
IntelRealSenseCamera,
)
config = IntelRealSenseCameraConfig(**kwargs) config = IntelRealSenseCameraConfig(**kwargs)
return IntelRealSenseCamera(config) return IntelRealSenseCamera(config)

View File

@@ -23,10 +23,7 @@ import numpy as np
import tqdm import tqdm
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
from lerobot.common.robot_devices.utils import ( from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 2.0 PROTOCOL_VERSION = 2.0
@@ -787,12 +784,7 @@ class DynamixelMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}" f"{self.packet_handler.getTxRxResult(comm)}"
) )
def write( def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected: if not self.is_connected:
raise RobotDeviceNotConnectedError( raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."

View File

@@ -23,10 +23,7 @@ import numpy as np
import tqdm import tqdm
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
from lerobot.common.robot_devices.utils import ( from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 0 PROTOCOL_VERSION = 0
@@ -812,12 +809,7 @@ class FeetechMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}" f"{self.packet_handler.getTxRxResult(comm)}"
) )
def write( def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected: if not self.is_connected:
raise RobotDeviceNotConnectedError( raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."

View File

@@ -30,9 +30,7 @@ class MotorsBus(Protocol):
def write(self): ... def write(self): ...
def make_motors_buses_from_configs( def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
motors_bus_configs: dict[str, MotorsBusConfig],
) -> list[MotorsBus]:
motors_buses = {} motors_buses = {}
for key, cfg in motors_bus_configs.items(): for key, cfg in motors_bus_configs.items():

View File

@@ -207,10 +207,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
print("Calibrate elbow_flex") print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate( calib["elbow_flex"] = move_to_calibrate(
arm, arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
"elbow_flex",
positive_first=False,
in_between_move_hook=in_between_move_hook,
) )
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024) calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
@@ -242,11 +239,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
} }
arm.write("Goal_Position", list(positions.values()), list(positions.keys())) arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
arm.write( arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
"Goal_Position",
round(calib["shoulder_lift"]["zero_pos"] - 1600),
"shoulder_lift",
)
time.sleep(2) time.sleep(2)
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex") arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
time.sleep(2) time.sleep(2)
@@ -257,11 +250,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
print("Calibrate wrist_roll") print("Calibrate wrist_roll")
calib["wrist_roll"] = move_to_calibrate( calib["wrist_roll"] = move_to_calibrate(
arm, arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
"wrist_roll",
invert_drive_mode=True,
positive_first=False,
while_move_hook=while_move_hook,
) )
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll") arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")

View File

@@ -61,9 +61,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
calib_dir.mkdir(parents=True, exist_ok=True) calib_dir.mkdir(parents=True, exist_ok=True)
calib_file = calib_dir / "main_follower.json" calib_file = calib_dir / "main_follower.json"
try: try:
from lerobot.common.robot_devices.robots.feetech_calibration import ( from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
run_arm_manual_calibration,
)
except ImportError: except ImportError:
print("[WARNING] Calibration function not available. Skipping calibration.") print("[WARNING] Calibration function not available. Skipping calibration.")
return return
@@ -118,14 +116,7 @@ def run_lekiwi(robot_config):
robot = LeKiwi(motors_bus) robot = LeKiwi(motors_bus)
# Define the expected arm motor IDs. # Define the expected arm motor IDs.
arm_motor_ids = [ arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
]
# Disable torque for each arm motor. # Disable torque for each arm motor.
for motor in arm_motor_ids: for motor in arm_motor_ids:
@@ -139,9 +130,7 @@ def run_lekiwi(robot_config):
images_lock = threading.Lock() images_lock = threading.Lock()
stop_event = threading.Event() stop_event = threading.Event()
cam_thread = threading.Thread( cam_thread = threading.Thread(
target=run_camera_capture, target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
args=(cameras, images_lock, latest_images_dict, stop_event),
daemon=True,
) )
cam_thread.start() cam_thread.start()

View File

@@ -25,14 +25,9 @@ import zmq
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.feetech import TorqueMode from lerobot.common.robot_devices.motors.feetech import TorqueMode
from lerobot.common.robot_devices.motors.utils import ( from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
MotorsBus,
make_motors_buses_from_configs,
)
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
from lerobot.common.robot_devices.robots.feetech_calibration import ( from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
run_arm_manual_calibration,
)
from lerobot.common.robot_devices.robots.utils import get_arm_id from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
@@ -329,11 +324,7 @@ class MobileManipulator:
socks = dict(poller.poll(15)) socks = dict(poller.poll(15))
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN: if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
# No new data arrived → reuse ALL old data # No new data arrived → reuse ALL old data
return ( return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
# Drain all messages, keep only the last # Drain all messages, keep only the last
last_msg = None last_msg = None
@@ -346,11 +337,7 @@ class MobileManipulator:
if not last_msg: if not last_msg:
# No new message → also reuse old # No new message → also reuse old
return ( return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
# Decode only the final message # Decode only the final message
try: try:
@@ -388,11 +375,7 @@ class MobileManipulator:
except Exception as e: except Exception as e:
print(f"[DEBUG] Error decoding video message: {e}") print(f"[DEBUG] Error decoding video message: {e}")
# If decode fails, fall back to old data # If decode fails, fall back to old data
return ( return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
return frames, present_speed, remote_arm_state_tensor return frames, present_speed, remote_arm_state_tensor
@@ -478,11 +461,7 @@ class MobileManipulator:
body_state = self.wheel_raw_to_body(present_speed) body_state = self.wheel_raw_to_body(present_speed)
body_state_mm = ( body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
body_state[0] * 1000.0,
body_state[1] * 1000.0,
body_state[2],
) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32) wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0) combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
@@ -641,11 +620,7 @@ class MobileManipulator:
# Convert each wheels angular speed (deg/s) to a raw integer. # Convert each wheels angular speed (deg/s) to a raw integer.
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps] wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
return { return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
"left_wheel": wheel_raw[0],
"back_wheel": wheel_raw[1],
"right_wheel": wheel_raw[2],
}
def wheel_raw_to_body( def wheel_raw_to_body(
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125 self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125

View File

@@ -72,9 +72,7 @@ def make_robot_from_config(config: RobotConfig):
return ManipulatorRobot(config) return ManipulatorRobot(config)
elif isinstance(config, LeKiwiRobotConfig): elif isinstance(config, LeKiwiRobotConfig):
from lerobot.common.robot_devices.robots.mobile_manipulator import ( from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
MobileManipulator,
)
return MobileManipulator(config) return MobileManipulator(config)
else: else:

View File

@@ -48,8 +48,7 @@ class RobotDeviceNotConnectedError(Exception):
"""Exception raised when the robot device is not connected.""" """Exception raised when the robot device is not connected."""
def __init__( def __init__(
self, self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
): ):
self.message = message self.message = message
super().__init__(self.message) super().__init__(self.message)

View File

@@ -42,11 +42,7 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non
""" """
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`. Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
""" """
py_state = ( py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
rng_state_dict["py_rng_version"].item(),
tuple(rng_state_dict["py_rng_state"].tolist()),
None,
)
random.setstate(py_state) random.setstate(py_state)

View File

@@ -42,10 +42,7 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis
args = sys.argv[1:] args = sys.argv[1:]
attr_level_args = [] attr_level_args = []
detect_string = f"--{field_name}." detect_string = f"--{field_name}."
exclude_strings = ( exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
f"--{field_name}.{PATH_KEY}=",
)
for arg in args: for arg in args:
if arg.startswith(detect_string) and not arg.startswith(exclude_strings): if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
denested_arg = f"--{arg.removeprefix(detect_string)}" denested_arg = f"--{arg.removeprefix(detect_string)}"

View File

@@ -26,11 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
auto_select_torch_device,
is_amp_available,
is_torch_device_available,
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Generic variable that is either PreTrainedConfig or a subclass thereof # Generic variable that is either PreTrainedConfig or a subclass thereof

View File

@@ -38,12 +38,7 @@ def get_motor_bus_cls(brand: str) -> tuple:
FeetechMotorsBus, FeetechMotorsBus,
) )
return ( return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE
FeetechMotorsBusConfig,
FeetechMotorsBus,
MODEL_BAUDRATE_TABLE,
SCS_SERIES_BAUDRATE_TABLE,
)
elif brand == "dynamixel": elif brand == "dynamixel":
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
@@ -53,12 +48,7 @@ def get_motor_bus_cls(brand: str) -> tuple:
DynamixelMotorsBus, DynamixelMotorsBus,
) )
return ( return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE
DynamixelMotorsBusConfig,
DynamixelMotorsBus,
MODEL_BAUDRATE_TABLE,
X_SERIES_BAUDRATE_TABLE,
)
else: else:
raise ValueError( raise ValueError(
@@ -174,25 +164,12 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
"--port",
type=str,
required=True,
help="Motors bus port (e.g. dynamixel,feetech)",
)
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)") parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)") parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
parser.add_argument( parser.add_argument(
"--ID", "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
type=int,
required=True,
help="Desired ID of the current motor (e.g. 1,2,3)",
)
parser.add_argument(
"--baudrate",
type=int,
default=1000000,
help="Desired baudrate for the motor (default: 1000000)",
) )
args = parser.parse_args() args = parser.parse_args()

View File

@@ -149,11 +149,7 @@ def init_sim_calibration(robot, cfg):
axis_directions = np.array(cfg.get("axis_directions", [1])) axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi offsets = np.array(cfg.get("offsets", [0])) * np.pi
return { return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
"start_pos": start_pos,
"axis_directions": axis_directions,
"offsets": offsets,
}
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets): def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
@@ -203,7 +199,6 @@ def record(
run_compute_stats: bool = True, run_compute_stats: bool = True,
) -> LeRobotDataset: ) -> LeRobotDataset:
# Load pretrained policy # Load pretrained policy
policy = None policy = None
if pretrained_policy_name_or_path is not None: if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -246,11 +241,7 @@ def record(
shape = env.observation_space[key].shape shape = env.observation_space[key].shape
if not key.startswith("observation.image."): if not key.startswith("observation.image."):
key = "observation.image." + key key = "observation.image." + key
features[key] = { features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}
"dtype": "video",
"names": ["channels", "height", "width"],
"shape": shape,
}
for key, obs_key in state_keys_dict.items(): for key, obs_key in state_keys_dict.items():
features[key] = { features[key] = {
@@ -259,11 +250,7 @@ def record(
"shape": env.observation_space[obs_key].shape, "shape": env.observation_space[obs_key].shape,
} }
features["action"] = { features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
"dtype": "float32",
"shape": env.action_space.shape,
"names": None,
}
# Create empty dataset or load existing saved episodes # Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy) sanity_check_dataset_name(repo_id, policy)
@@ -374,12 +361,7 @@ def record(
def replay( def replay(
env, env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
local_files_only: bool = True,
): ):
env = env() env = env()
@@ -426,10 +408,7 @@ if __name__ == "__main__":
parser_record = subparsers.add_parser("record", parents=[base_parser]) parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument( parser_record.add_argument(
"--fps", "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
) )
parser_record.add_argument( parser_record.add_argument(
"--root", "--root",
@@ -507,19 +486,9 @@ if __name__ == "__main__":
default=0, default=0,
help="Resume recording on an existing dataset.", help="Resume recording on an existing dataset.",
) )
parser_record.add_argument(
"--assign-rewards",
type=int,
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument( parser_replay.add_argument(
"--fps", "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
) )
parser_replay.add_argument( parser_replay.add_argument(
"--root", "--root",

View File

@@ -293,8 +293,7 @@ def eval_policy(
seeds = None seeds = None
else: else:
seeds = range( seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
start_seed + ((batch_ix + 1) * env.num_envs),
) )
rollout_data = rollout( rollout_data = rollout(
env, env,
@@ -414,11 +413,7 @@ def eval_policy(
def _compile_episode_data( def _compile_episode_data(
rollout_data: dict, rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
done_indices: Tensor,
start_episode_index: int,
start_data_index: int,
fps: float,
) -> dict: ) -> dict:
"""Convenience function for `eval_policy(return_episode_data=True)` """Convenience function for `eval_policy(return_episode_data=True)`
@@ -486,10 +481,7 @@ def eval_main(cfg: EvalPipelineConfig):
) )
policy.eval() policy.eval()
with ( with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
info = eval_policy( info = eval_policy(
env, env,
policy, policy,

View File

@@ -196,11 +196,7 @@ def train(cfg: TrainPipelineConfig):
} }
train_tracker = MetricsTracker( train_tracker = MetricsTracker(
cfg.batch_size, cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
dataset.num_frames,
dataset.num_episodes,
train_metrics,
initial_step=step,
) )
logging.info("Start offline training on a fixed dataset") logging.info("Start offline training on a fixed dataset")
@@ -271,11 +267,7 @@ def train(cfg: TrainPipelineConfig):
"eval_s": AverageMeter("eval_s", ":.3f"), "eval_s": AverageMeter("eval_s", ":.3f"),
} }
eval_tracker = MetricsTracker( eval_tracker = MetricsTracker(
cfg.batch_size, cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
) )
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s") eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward") eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")

View File

@@ -81,11 +81,7 @@ def run_server(
static_folder: Path, static_folder: Path,
template_folder: Path, template_folder: Path,
): ):
app = Flask( app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
__name__,
static_folder=static_folder.resolve(),
template_folder=template_folder.resolve(),
)
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/") @app.route("/")
@@ -201,8 +197,7 @@ def run_server(
] ]
response = requests.get( response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
timeout=5,
) )
response.raise_for_status() response.raise_for_status()
# Split into lines and parse each line as JSON # Split into lines and parse each line as JSON
@@ -287,8 +282,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
repo_id = dataset.repo_id repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
episode_index=episode_index,
) )
df = pd.read_parquet(url) df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns data = df[selected_columns] # Select specific columns
@@ -337,8 +331,7 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
def get_dataset_info(repo_id: str) -> IterableNamespace: def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get( response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
timeout=5,
) )
response.raise_for_status() # Raises an HTTPError for bad responses response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json() dataset_info = response.json()

View File

@@ -71,7 +71,7 @@ dependencies = [
"rerun-sdk>=0.21.0", "rerun-sdk>=0.21.0",
"termcolor>=2.4.0", "termcolor>=2.4.0",
"torch>=2.2.1,<2.7", "torch>=2.2.1,<2.7",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", "torchcodec==0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchmetrics>=1.6.0", "torchmetrics>=1.6.0",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"transformers>=4.47.0", "transformers>=4.47.0",

View File

@@ -59,33 +59,16 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
"action": { "action": {
"dtype": "float32", "dtype": "float32",
"shape": (6,), "shape": (6,),
"names": [ "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
}, },
"observation.state": { "observation.state": {
"dtype": "float32", "dtype": "float32",
"shape": (6,), "shape": (6,),
"names": [ "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
}, },
} }
info = info_factory( info = info_factory(
total_episodes=1, total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
total_frames=1,
camera_features=camera_features,
motor_features=motor_features,
) )
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info) ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta return ds_meta
@@ -98,8 +81,7 @@ def test_get_policy_and_config_classes(policy_name: str):
policy_cfg = make_policy_config(policy_name) policy_cfg = make_policy_config(policy_name)
assert policy_cls.name == policy_name assert policy_cls.name == policy_name
assert issubclass( assert issubclass(
policy_cfg.__class__, policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
inspect.signature(policy_cls.__init__).parameters["config"].annotation,
) )
@@ -110,13 +92,7 @@ def test_get_policy_and_config_classes(policy_name: str):
("lerobot/pusht", "pusht", {}, "diffusion", {}), ("lerobot/pusht", "pusht", {}, "diffusion", {}),
("lerobot/pusht", "pusht", {}, "vqbet", {}), ("lerobot/pusht", "pusht", {}, "vqbet", {}),
("lerobot/pusht", "pusht", {}, "act", {}), ("lerobot/pusht", "pusht", {}, "act", {}),
( ("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
"lerobot/aloha_sim_insertion_human",
"aloha",
{"task": "AlohaInsertion-v0"},
"act",
{},
),
( (
"lerobot/aloha_sim_insertion_scripted", "lerobot/aloha_sim_insertion_scripted",
"aloha", "aloha",

View File

@@ -390,8 +390,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"robot_type, mock, num_image_writer_processes", "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
[("koch", True, 0), ("koch", True, 1)],
) )
@require_robot @require_robot
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes): def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):

View File

@@ -40,10 +40,7 @@ import pytest
import torch import torch
from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.robot_devices.utils import ( from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot

View File

@@ -26,9 +26,7 @@ from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.motors.utils import MotorsBus from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.motors.utils import ( from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
make_motors_bus as make_motors_bus_device,
)
from lerobot.common.utils.import_utils import is_package_available from lerobot.common.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
@@ -287,14 +285,7 @@ def mock_calibration_dir(calibration_dir):
"start_pos": [1442, 843, 2166, 2849, 1988, 1835], "start_pos": [1442, 843, 2166, 2849, 1988, 1835],
"end_pos": [2440, 1869, -1106, -1848, -926, 3235], "end_pos": [2440, 1869, -1106, -1848, -926, 3235],
"calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"], "calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"],
"motor_names": [ "motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
} }
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True) Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
with open(calibration_dir / "main_follower.json", "w") as f: with open(calibration_dir / "main_follower.json", "w") as f:

View File

@@ -18,10 +18,7 @@ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
@pytest.fixture @pytest.fixture
def mock_metrics(): def mock_metrics():
return { return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
"loss": AverageMeter("loss", ":.3f"),
"accuracy": AverageMeter("accuracy", ":.2f"),
}
def test_average_meter_initialization(): def test_average_meter_initialization():
@@ -61,11 +58,7 @@ def test_average_meter_str():
def test_metrics_tracker_initialization(mock_metrics): def test_metrics_tracker_initialization(mock_metrics):
tracker = MetricsTracker( tracker = MetricsTracker(
batch_size=32, batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=10,
) )
assert tracker.steps == 10 assert tracker.steps == 10
assert tracker.samples == 10 * 32 assert tracker.samples == 10 * 32
@@ -77,11 +70,7 @@ def test_metrics_tracker_initialization(mock_metrics):
def test_metrics_tracker_step(mock_metrics): def test_metrics_tracker_step(mock_metrics):
tracker = MetricsTracker( tracker = MetricsTracker(
batch_size=32, batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=5,
) )
tracker.step() tracker.step()
assert tracker.steps == 6 assert tracker.steps == 6