Compare commits

...

15 Commits

Author SHA1 Message Date
Steven Palma  e511e7eda5  Merge branch 'main' into fix/lint_warnings  2025-03-10 09:39:00 +01:00
Steven Palma  f5ed3723f0  fix(tests): typo in fixture name  2025-03-07 18:21:55 +01:00
Steven Palma  b104be0d04  Merge branch 'main' into fix/lint_warnings  2025-03-07 18:07:30 +01:00
Steven Palma  f9e4a1f5c4  chore(style): fix format  2025-03-07 17:44:18 +01:00
Steven Palma  0eb56cec14  fix(tests): remove lint warnings/errors  2025-03-07 17:44:18 +01:00
Steven Palma  e59ef036e1  fix(lerobot/common/datasets): remove lint warnings/errors  2025-03-07 16:50:22 +01:00
Steven Palma  9b380eaf67  fix(lerobot/common/envs): remove lint warnings/errors  2025-03-07 16:50:22 +01:00
Steven Palma  1187604ba0  fix(lerobot/common/optim): remove lint warnings/errors  2025-03-07 16:50:22 +01:00
Steven Palma  5c6f2d2cd0  fix(lerobot/common/policies): remove lint warnings/errors  2025-03-07 16:50:22 +01:00
Steven Palma  652fedf69c  fix(lerobot/common/robot_devices): remove lint warnings/errors  2025-03-07 16:50:22 +01:00
Steven Palma  85214ec303  fix(lerobot/common/utils): remove lint warnings/errors  2025-03-07 16:50:22 +01:00
Steven Palma  dffa5a18db  fix(lerobot/configs): remove lint warning/errors  2025-03-07 16:50:22 +01:00
Steven Palma  301f152a34  fix(lerobot/scripts): remove lint warnings/errors  2025-03-07 16:50:21 +01:00
Steven Palma  0ed08c0b1f  fix(examples): remove lint warnings/errors  2025-03-07 14:26:33 +01:00
Steven Palma  254bc707e7  fix(benchmarks): remove lint warnings/errors  2025-03-07 14:25:42 +01:00
80 changed files with 552 additions and 453 deletions

View File

@@ -67,6 +67,7 @@ def parse_int_or_none(value) -> int | None:
 def check_datasets_formats(repo_ids: list) -> None:
     for repo_id in repo_ids:
         dataset = LeRobotDataset(repo_id)
+        # TODO(Steven): Seems this API has changed
         if dataset.video:
             raise ValueError(
                 f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"

View File

@@ -222,7 +222,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
 if __name__ == "__main__":
     # To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
-    repo_id = "lerobot/pusht"
+    repository_id = "lerobot/pusht"
     modes = ["video", "image", "keypoints"]
     # Uncomment if you want to try with a specific mode
@@ -230,13 +230,13 @@ if __name__ == "__main__":
     # modes = ["image"]
     # modes = ["keypoints"]
-    raw_dir = Path("data/lerobot-raw/pusht_raw")
-    for mode in modes:
-        if mode in ["image", "keypoints"]:
-            repo_id += f"_{mode}"
+    data_dir = Path("data/lerobot-raw/pusht_raw")
+    for available_mode in modes:
+        if available_mode in ["image", "keypoints"]:
+            repository_id += f"_{available_mode}"
         # download and load raw dataset, create LeRobotDataset, populate it, push to hub
-        main(raw_dir, repo_id=repo_id, mode=mode)
+        main(data_dir, repo_id=repository_id, mode=available_mode)
         # Uncomment if you want to load the local dataset and explore it
         # dataset = LeRobotDataset(repo_id=repo_id)

View File

@@ -13,8 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-from pprint import pformat
 import torch
@@ -98,17 +96,17 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
         )
     else:
         raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
-        dataset = MultiLeRobotDataset(
-            cfg.dataset.repo_id,
-            # TODO(aliberts): add proper support for multi dataset
-            # delta_timestamps=delta_timestamps,
-            image_transforms=image_transforms,
-            video_backend=cfg.dataset.video_backend,
-        )
-        logging.info(
-            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-            f"{pformat(dataset.repo_id_to_index, indent=2)}"
-        )
+        # dataset = MultiLeRobotDataset(
+        #     cfg.dataset.repo_id,
+        #     # TODO(aliberts): add proper support for multi dataset
+        #     # delta_timestamps=delta_timestamps,
+        #     image_transforms=image_transforms,
+        #     video_backend=cfg.dataset.video_backend,
+        # )
+        # logging.info(
+        #     "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
+        #     f"{pformat(dataset.repo_id_to_index, indent=2)}"
+        # )
     if cfg.dataset.use_imagenet_stats:
         for key in dataset.meta.camera_keys:

View File

@@ -81,21 +81,21 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
         print(f"Error writing image {fpath}: {e}")
-def worker_thread_loop(queue: queue.Queue):
+def worker_thread_loop(task_queue: queue.Queue):
     while True:
-        item = queue.get()
+        item = task_queue.get()
         if item is None:
-            queue.task_done()
+            task_queue.task_done()
             break
         image_array, fpath = item
         write_image(image_array, fpath)
-        queue.task_done()
+        task_queue.task_done()
-def worker_process(queue: queue.Queue, num_threads: int):
+def worker_process(task_queue: queue.Queue, num_threads: int):
     threads = []
     for _ in range(num_threads):
-        t = threading.Thread(target=worker_thread_loop, args=(queue,))
+        t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
         t.daemon = True
         t.start()
         threads.append(t)
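The rename in this hunk avoids a parameter that shadows the imported queue module. A minimal, self-contained sketch of the same pattern (names are illustrative, not taken from LeRobot):

import queue
import threading


def worker(task_queue: queue.Queue) -> None:
    # Using `task_queue` instead of `queue` keeps the module name usable inside the function.
    while True:
        item = task_queue.get()
        if item is None:
            task_queue.task_done()
            break
        print("processing", item)
        task_queue.task_done()


q = queue.Queue()
threading.Thread(target=worker, args=(q,), daemon=True).start()
q.put("job-1")
q.put(None)
q.join()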

View File

@@ -87,6 +87,7 @@ class LeRobotDatasetMetadata:
         self.repo_id = repo_id
         self.revision = revision if revision else CODEBASE_VERSION
         self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
+        self.stats = None
         try:
             if force_cache_sync:
@@ -102,10 +103,10 @@ class LeRobotDatasetMetadata:
     def load_metadata(self):
         self.info = load_info(self.root)
-        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
+        check_version_compatibility(self.repo_id, self.version, CODEBASE_VERSION)
         self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
-        if self._version < packaging.version.parse("v2.1"):
+        if self.version < packaging.version.parse("v2.1"):
             self.stats = load_stats(self.root)
             self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
         else:
@@ -127,7 +128,7 @@ class LeRobotDatasetMetadata:
         )
     @property
-    def _version(self) -> packaging.version.Version:
+    def version(self) -> packaging.version.Version:
         """Codebase version used to create this dataset."""
         return packaging.version.parse(self.info["codebase_version"])
@@ -321,8 +322,9 @@ class LeRobotDatasetMetadata:
             robot_type = robot.robot_type
             if not all(cam.fps == fps for cam in robot.cameras.values()):
                 logging.warning(
-                    f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
-                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
+                    "Some cameras in your %s robot don't have an fps matching the fps of your dataset."
+                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks.",
+                    robot.robot_type,
                 )
         elif features is None:
             raise ValueError(
@@ -486,7 +488,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.meta = LeRobotDatasetMetadata(
             self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
         )
-        if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
+        if self.episodes is not None and self.meta.version >= packaging.version.parse("v2.1"):
             episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
             self.stats = aggregate_stats(episodes_stats)
@@ -518,7 +520,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self,
         branch: str | None = None,
         tags: list | None = None,
-        license: str | None = "apache-2.0",
+        dataset_license: str | None = "apache-2.0",
         tag_version: bool = True,
         push_videos: bool = True,
         private: bool = False,
@@ -561,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
             card = create_lerobot_dataset_card(
-                tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
+                tags=tags, dataset_info=self.meta.info, license=dataset_license, **card_kwargs
             )
             card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
@@ -842,6 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
                 None.
         """
+        episode_buffer = None
         if not episode_data:
             episode_buffer = self.episode_buffer
@@ -1086,8 +1089,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
             extra_keys = set(ds.features).difference(intersection_features)
             logging.warning(
-                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
-                "other datasets."
+                "keys %s of %s were disabled as they are not contained in all the other datasets.",
+                extra_keys,
+                repo_id,
             )
             self.disabled_features.update(extra_keys)

View File

@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
     # rechunk recompress
     group.move(name, tmp_key)
     old_arr = group[tmp_key]
-    n_copied, n_skipped, n_bytes_copied = zarr.copy(
+    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
         source=old_arr,
         dest=group,
         name=name,
@@ -192,7 +192,7 @@ class ReplayBuffer:
         else:
             root = zarr.group(store=store)
             # copy without recompression
-            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
+            _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
                 source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
             )
         data_group = root.create_group("data", overwrite=True)
@@ -205,7 +205,7 @@ class ReplayBuffer:
                 if cks == value.chunks and cpr == value.compressor:
                     # copy without recompression
                     this_path = "/data/" + key
-                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
+                    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
                         source=src_store,
                         dest=store,
                         source_path=this_path,
@@ -214,7 +214,7 @@ class ReplayBuffer:
                     )
                 else:
                     # copy with recompression
-                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
+                    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
                         source=value,
                         dest=data_group,
                         name=key,
@@ -275,7 +275,7 @@ class ReplayBuffer:
         compressors = {}
         if self.backend == "zarr":
             # recompression free copy
-            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
+            _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
                 source=self.root.store,
                 dest=store,
                 source_path="/meta",
@@ -297,7 +297,7 @@ class ReplayBuffer:
                 if cks == value.chunks and cpr == value.compressor:
                     # copy without recompression
                     this_path = "/data/" + key
-                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
+                    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
                         source=self.root.store,
                         dest=store,
                         source_path=this_path,

View File

@@ -162,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
     )
     raw_dir.mkdir(parents=True, exist_ok=True)
-    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
+    logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
     snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
-    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
+    logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
 def download_all_raw_datasets(data_dir: Path | None = None):
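These logging changes swap eager f-strings for printf-style arguments, so the message is only formatted when a handler actually emits the record. A standalone sketch of the two forms, independent of this codebase:

import logging

logging.basicConfig(level=logging.INFO)
user_id, dataset_id = "lerobot", "pusht"

# Eager: the f-string is built even if the INFO level is disabled.
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")

# Lazy: arguments are interpolated by the logging framework only when the record is handled.
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)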

View File

@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
                 assert data[f"/observations/images/{camera}"].ndim == 2
             else:
                 assert data[f"/observations/images/{camera}"].ndim == 4
-                b, h, w, c = data[f"/observations/images/{camera}"].shape
+                _, h, w, c = data[f"/observations/images/{camera}"].shape
                 assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
@@ -103,6 +103,7 @@ def load_from_raw(
         state = torch.from_numpy(ep["/observations/qpos"][:])
         action = torch.from_numpy(ep["/action"][:])
+        velocity = None
         if "/observations/qvel" in ep:
             velocity = torch.from_numpy(ep["/observations/qvel"][:])
         if "/observations/effort" in ep:

View File

@@ -96,6 +96,7 @@ def from_raw_to_lerobot_format(
     if fps is None:
         fps = 30
+    # TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
     data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)

View File

@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
     return True
-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
+):
     # Load data stream that will be used as reference for the timestamps synchronization
     reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
     if len(reference_files) == 0:

View File

@@ -55,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
     num_images = len(imgs_array)
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
+        _ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
 def get_default_encoding() -> dict:
@@ -92,24 +92,23 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
     episode_data_index = {"from": [], "to": []}
     current_episode = None
-    """
-    The episode_index is a list of integers, each representing the episode index of the corresponding example.
-    For instance, the following is a valid episode_index:
-      [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
-
-    Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
-    ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
-    {
-        "from": [0, 3, 7],
-        "to": [3, 7, 12]
-    }
-    """
+    # The episode_index is a list of integers, each representing the episode index of the corresponding example.
+    # For instance, the following is a valid episode_index:
+    #   [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
+    #
+    # Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
+    # ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
+    # {
+    #     "from": [0, 3, 7],
+    #     "to": [3, 7, 12]
+    # }
     if len(hf_dataset) == 0:
         episode_data_index = {
             "from": torch.tensor([]),
             "to": torch.tensor([]),
         }
         return episode_data_index
+    idx = None
     for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
         if episode_idx != current_episode:
             # We encountered a new episode, so we append its starting location to the "from" list

View File

@@ -23,6 +23,7 @@ from torchvision.transforms.v2 import Transform
 from torchvision.transforms.v2 import functional as F  # noqa: N812
+# TODO(Steven): Missing transform() implementation
 class RandomSubsetApply(Transform):
     """Apply a random subset of N transformations from a list of transformations.
@@ -218,6 +219,7 @@ def make_transform_from_config(cfg: ImageTransformConfig):
         raise ValueError(f"Transform '{cfg.type}' is not valid.")
+# TODO(Steven): Missing transform() implementation
 class ImageTransforms(Transform):
     """A class to compose image transforms based on configuration."""

View File

@@ -135,21 +135,21 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
 def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
     # Embed image bytes into the table before saving to parquet
-    format = dataset.format
+    ds_format = dataset.format
     dataset = dataset.with_format("arrow")
     dataset = dataset.map(embed_table_storage, batched=False)
-    dataset = dataset.with_format(**format)
+    dataset = dataset.with_format(**ds_format)
     return dataset
 def load_json(fpath: Path) -> Any:
-    with open(fpath) as f:
+    with open(fpath, encoding="utf-8") as f:
         return json.load(f)
 def write_json(data: dict, fpath: Path) -> None:
     fpath.parent.mkdir(exist_ok=True, parents=True)
-    with open(fpath, "w") as f:
+    with open(fpath, "w", encoding="utf-8") as f:
         json.dump(data, f, indent=4, ensure_ascii=False)
@@ -300,7 +300,7 @@ def check_version_compatibility(
     if v_check.major < v_current.major and enforce_breaking_major:
         raise BackwardCompatibilityError(repo_id, v_check)
     elif v_check.minor < v_current.minor:
-        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
+        logging.warning("%s", V21_MESSAGE.format(repo_id=repo_id, version=v_check))
 def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
@@ -348,7 +348,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
     if compatibles:
         return_version = max(compatibles)
         if return_version < target_version:
-            logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
+            logging.warning(
+                "Revision %s for %s not found, using version v%s", version, repo_id, return_version
+            )
         return f"v{return_version}"
     lower_major = [v for v in hub_versions if v.major < target_version.major]
@@ -403,7 +405,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
     for key, ft in features.items():
         shape = ft["shape"]
         if ft["dtype"] in ["image", "video"]:
-            type = FeatureType.VISUAL
+            feature_type = FeatureType.VISUAL
             if len(shape) != 3:
                 raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
@@ -412,16 +414,16 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
             if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                 shape = (shape[2], shape[0], shape[1])
         elif key == "observation.environment_state":
-            type = FeatureType.ENV
+            feature_type = FeatureType.ENV
         elif key.startswith("observation"):
-            type = FeatureType.STATE
+            feature_type = FeatureType.STATE
         elif key == "action":
-            type = FeatureType.ACTION
+            feature_type = FeatureType.ACTION
         else:
             continue
         policy_features[key] = PolicyFeature(
-            type=type,
+            type=feature_type,
             shape=shape,
         )

View File

@@ -871,11 +871,11 @@ def batch_convert():
         try:
             convert_dataset(repo_id, LOCAL_DIR, **kwargs)
             status = f"{repo_id}: success."
-            with open(logfile, "a") as file:
+            with open(logfile, "a", encoding="utf-8") as file:
                 file.write(status + "\n")
         except Exception:
             status = f"{repo_id}: failed\n {traceback.format_exc()}"
-            with open(logfile, "a") as file:
+            with open(logfile, "a", encoding="utf-8") as file:
                 file.write(status + "\n")
             continue
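The added encoding="utf-8" arguments address the unspecified-encoding lint: in text mode, open() otherwise falls back to the platform's locale encoding. A small sketch of the pattern (the file name is illustrative):

from pathlib import Path

logfile = Path("convert_log.txt")

# Explicit encoding makes the file readable the same way on every platform.
with open(logfile, "a", encoding="utf-8") as f:
    f.write("lerobot/pusht: success.\n")

with open(logfile, encoding="utf-8") as f:
    print(f.read())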

View File

@@ -190,11 +190,11 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
     json_path = v2_dir / STATS_PATH
     json_path.parent.mkdir(exist_ok=True, parents=True)
-    with open(json_path, "w") as f:
+    with open(json_path, "w", encoding="utf-8") as f:
         json.dump(serialized_stats, f, indent=4)
     # Sanity check
-    with open(json_path) as f:
+    with open(json_path, encoding="utf-8") as f:
         stats_json = json.load(f)
     stats_json = flatten_dict(stats_json)
@@ -213,7 +213,7 @@ def get_features_from_hf_dataset(
             dtype = ft.dtype
             shape = (1,)
             names = None
-        if isinstance(ft, datasets.Sequence):
+        elif isinstance(ft, datasets.Sequence):
             assert isinstance(ft.feature, datasets.Value)
             dtype = ft.feature.dtype
             shape = (ft.length,)
@@ -232,6 +232,8 @@ def get_features_from_hf_dataset(
             dtype = "video"
             shape = None  # Add shape later
             names = ["height", "width", "channels"]
+        else:
+            raise NotImplementedError(f"Feature type {ft._type} not supported.")
         features[key] = {
             "dtype": dtype,
@@ -358,9 +360,9 @@ def move_videos(
         if len(video_dirs) == 1:
             video_path = video_dirs[0] / video_file
         else:
-            for dir in video_dirs:
-                if (dir / video_file).is_file():
-                    video_path = dir / video_file
+            for v_dir in video_dirs:
+                if (v_dir / video_file).is_file():
+                    video_path = v_dir / video_file
                     break
         video_path.rename(work_dir / target_path)
@@ -652,6 +654,7 @@ def main():
     if not args.local_dir:
         args.local_dir = Path("/tmp/lerobot_dataset_v2")
+    robot_config = None
     if args.robot is not None:
         robot_config = make_robot_config(args.robot)

View File

@@ -50,7 +50,7 @@ def fix_dataset(repo_id: str) -> str:
         return f"{repo_id}: skipped (no diff)"
     if diff_meta_parquet:
-        logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
+        logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
         assert diff_meta_parquet == {"language_instruction"}
         lerobot_metadata.features.pop("language_instruction")
         write_info(lerobot_metadata.info, lerobot_metadata.root)
@@ -79,7 +79,7 @@ def batch_fix():
             status = f"{repo_id}: failed\n {traceback.format_exc()}"
             logging.info(status)
-            with open(logfile, "a") as file:
+            with open(logfile, "a", encoding="utf-8") as file:
                 file.write(status + "\n")

View File

@@ -46,7 +46,7 @@ def batch_convert():
         except Exception:
             status = f"{repo_id}: failed\n {traceback.format_exc()}"
-            with open(logfile, "a") as file:
+            with open(logfile, "a", encoding="utf-8") as file:
                 file.write(status + "\n")

View File

@@ -45,6 +45,9 @@ V21 = "v2.1"
 class SuppressWarnings:
+    def __init__(self):
+        self.previous_level = None
     def __enter__(self):
         self.previous_level = logging.getLogger().getEffectiveLevel()
         logging.getLogger().setLevel(logging.ERROR)
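Declaring previous_level in __init__ keeps every attribute defined before __enter__ runs, which is what the attributes-defined-outside-init style check wants. A self-contained sketch of such a context manager, assuming it mirrors the one touched here:

import logging


class SuppressWarnings:
    def __init__(self):
        # Declared up front so the attribute always exists, even before __enter__ runs.
        self.previous_level = None

    def __enter__(self):
        self.previous_level = logging.getLogger().getEffectiveLevel()
        logging.getLogger().setLevel(logging.ERROR)

    def __exit__(self, exc_type, exc_val, exc_tb):
        logging.getLogger().setLevel(self.previous_level)


with SuppressWarnings():
    logging.warning("hidden while the context manager is active")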

View File

@@ -83,7 +83,7 @@ def decode_video_frames_torchvision(
     for frame in reader:
         current_ts = frame["pts"]
         if log_loaded_timestamps:
-            logging.info(f"frame loaded at timestamp={current_ts:.4f}")
+            logging.info("frame loaded at timestamp=%.4f", current_ts)
         loaded_frames.append(frame["data"])
         loaded_ts.append(current_ts)
         if current_ts >= last_ts:
@@ -118,7 +118,7 @@ def decode_video_frames_torchvision(
     closest_ts = loaded_ts[argmin_]
     if log_loaded_timestamps:
-        logging.info(f"{closest_ts=}")
+        logging.info("closest_ts=%s", closest_ts)
     # convert to the pytorch format which is float32 in [0,1] range (and channel first)
     closest_frames = closest_frames.type(torch.float32) / 255
@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
         "json",
         str(video_path),
     ]
-    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    result = subprocess.run(
+        ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
+    )
     if result.returncode != 0:
         raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@@ -263,7 +265,9 @@ def get_video_info(video_path: Path | str) -> dict:
         "json",
         str(video_path),
     ]
-    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    result = subprocess.run(
+        ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
+    )
     if result.returncode != 0:
         raise RuntimeError(f"Error running ffprobe: {result.stderr}")
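With check=True, subprocess.run raises subprocess.CalledProcessError on a non-zero exit code, so a failed command surfaces as an exception instead of a silently ignored return value. A minimal sketch of that behavior (the ffprobe command line is illustrative):

import subprocess

cmd = ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "json", "video.mp4"]
try:
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True)
    print(result.stdout)
except subprocess.CalledProcessError as e:
    # stderr of the failed command is still available on the exception.
    print("ffprobe failed:", e.stderr)
except FileNotFoundError:
    print("ffprobe is not installed")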

View File

@@ -32,7 +32,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
     def type(self) -> str:
         return self.get_choice_name(self.__class__)
-    @abc.abstractproperty
+    @property
+    @abc.abstractmethod
     def gym_kwargs(self) -> dict:
         raise NotImplementedError()
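abc.abstractproperty has been deprecated since Python 3.3; the replacement applied in this hunk stacks @property on top of @abc.abstractmethod. A standalone sketch of the idiom (class names here are illustrative):

import abc


class EnvConfigBase(abc.ABC):
    @property
    @abc.abstractmethod
    def gym_kwargs(self) -> dict:
        """Subclasses must provide the kwargs passed to the environment factory."""
        raise NotImplementedError


class AlohaConfig(EnvConfigBase):
    @property
    def gym_kwargs(self) -> dict:
        return {"obs_type": "pixels_agent_pos"}


print(AlohaConfig().gym_kwargs)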

View File

@@ -44,7 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
         return "adam"
     @abc.abstractmethod
-    def build(self) -> torch.optim.Optimizer:
+    def build(self, params: dict) -> torch.optim.Optimizer:
         raise NotImplementedError

View File

@@ -140,7 +140,7 @@ class ACTConfig(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."

View File

@@ -222,6 +222,8 @@ class ACTTemporalEnsembler:
         self.chunk_size = chunk_size
         self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
         self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
+        self.ensembled_actions = None
+        self.ensembled_actions_count = None
         self.reset()
     def reset(self):

View File

@@ -162,7 +162,7 @@ class DiffusionConfig(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."

View File

@@ -170,6 +170,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
         raise ValueError(f"Unsupported noise scheduler type {name}")
+# TODO(Steven): Missing forward() implementation
 class DiffusionModel(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
@@ -203,6 +204,7 @@ class DiffusionModel(nn.Module):
         )
         if config.num_inference_steps is None:
+            # TODO(Steven): Consider type check?
             self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
         else:
             self.num_inference_steps = config.num_inference_steps
@@ -333,7 +335,7 @@ class DiffusionModel(nn.Module):
         # Sample a random noising timestep for each item in the batch.
         timesteps = torch.randint(
             low=0,
-            high=self.noise_scheduler.config.num_train_timesteps,
+            high=self.noise_scheduler.config.num_train_timesteps,  # TODO(Steven): Consider type check?
             size=(trajectory.shape[0],),
             device=trajectory.device,
         ).long()

View File

@@ -69,12 +69,12 @@ def create_stats_buffers(
                 }
             )
         elif norm_mode is NormalizationMode.MIN_MAX:
-            min = torch.ones(shape, dtype=torch.float32) * torch.inf
-            max = torch.ones(shape, dtype=torch.float32) * torch.inf
+            min_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
+            max_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
             buffer = nn.ParameterDict(
                 {
-                    "min": nn.Parameter(min, requires_grad=False),
-                    "max": nn.Parameter(max, requires_grad=False),
+                    "min": nn.Parameter(min_norm, requires_grad=False),
+                    "max": nn.Parameter(max_norm, requires_grad=False),
                 }
             )
@@ -170,12 +170,12 @@ class Normalize(nn.Module):
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
                 batch[key] = (batch[key] - mean) / (std + 1e-8)
             elif norm_mode is NormalizationMode.MIN_MAX:
-                min = buffer["min"]
-                max = buffer["max"]
-                assert not torch.isinf(min).any(), _no_stats_error_str("min")
-                assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                min_norm = buffer["min"]
+                max_norm = buffer["max"]
+                assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
+                assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
                 # normalize to [0,1]
-                batch[key] = (batch[key] - min) / (max - min + 1e-8)
+                batch[key] = (batch[key] - min_norm) / (max_norm - min_norm + 1e-8)
                 # normalize to [-1, 1]
                 batch[key] = batch[key] * 2 - 1
             else:
@@ -243,12 +243,12 @@ class Unnormalize(nn.Module):
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
                 batch[key] = batch[key] * std + mean
             elif norm_mode is NormalizationMode.MIN_MAX:
-                min = buffer["min"]
-                max = buffer["max"]
-                assert not torch.isinf(min).any(), _no_stats_error_str("min")
-                assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                min_norm = buffer["min"]
+                max_norm = buffer["max"]
+                assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
+                assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
                 batch[key] = (batch[key] + 1) / 2
-                batch[key] = batch[key] * (max - min) + min
+                batch[key] = batch[key] * (max_norm - min_norm) + min_norm
             else:
                 raise ValueError(norm_mode)
         return batch
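Renaming min/max to min_norm/max_norm stops these buffers from shadowing the built-ins inside the branch. A compact sketch of the same min-max normalization, written independently of the policy code:

import torch


def min_max_normalize(x: torch.Tensor, min_norm: torch.Tensor, max_norm: torch.Tensor) -> torch.Tensor:
    # Scale to [0, 1], then to [-1, 1]; the epsilon guards against a zero range.
    x = (x - min_norm) / (max_norm - min_norm + 1e-8)
    return x * 2 - 1


batch = torch.tensor([0.0, 5.0, 10.0])
print(min_max_normalize(batch, batch.min(), batch.max()))  # ≈ tensor([-1., 0., 1.])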

View File

@@ -91,7 +91,7 @@ class PI0Config(PreTrainedConfig):
         super().__post_init__()
         # TODO(Steven): Validate device and amp? in all policy configs?
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if self.n_action_steps > self.chunk_size:
             raise ValueError(
                 f"The chunk size is the upper bound for the number of action steps per model invocation. Got "

View File

@@ -55,7 +55,7 @@ def main():
     with open(save_dir / "noise.pkl", "rb") as f:
         noise = pickle.load(f)
-    with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
+    with open(ckpt_jax_dir / "assets/norm_stats.json", encoding="utf-8") as f:
         norm_stats = json.load(f)
     # Override stats

View File

@@ -318,7 +318,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
     return {f"{prefix}{key}": value for key, value in d.items()}
-def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
+def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, _tokenizer_id: str, output_path: str):
     # Break down orbax ckpts - they are in OCDBT
     initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
     # process projection params
@@ -432,6 +432,6 @@ if __name__ == "__main__":
     convert_pi0_checkpoint(
         checkpoint_dir=args.checkpoint_dir,
         precision=args.precision,
-        tokenizer_id=args.tokenizer_hub_id,
+        _tokenizer_id=args.tokenizer_hub_id,
         output_path=args.output_path,
     )

View File

@@ -16,6 +16,7 @@ import torch
 import torch.nn.functional as F  # noqa: N812
 from packaging.version import Version
+# TODO(Steven): Consider settings this a dependency constraint
 if Version(torch.__version__) > Version("2.5.0"):
     # Ffex attention is only available from torch 2.5 onwards
     from torch.nn.attention.flex_attention import (
@@ -121,7 +122,7 @@ def flex_attention_forward(
     )
     # mask is applied inside the kernel, ideally more efficiently than score_mod.
-    attn_output, attention_weights = flex_attention(
+    attn_output, _attention_weights = flex_attention(
         query_states,
         key_states,
         value_states,

View File

@@ -162,7 +162,7 @@ class TDMPCConfig(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if self.n_gaussian_samples <= 0:
             raise ValueError(
                 f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"

View File

@@ -88,6 +88,9 @@ class TDMPCPolicy(PreTrainedPolicy):
         for param in self.model_target.parameters():
             param.requires_grad = False
+        self._queues = None
+        self._prev_mean: torch.Tensor | None = None
         self.reset()
     def get_optim_params(self) -> dict:
@@ -108,7 +111,7 @@ class TDMPCPolicy(PreTrainedPolicy):
             self._queues["observation.environment_state"] = deque(maxlen=1)
         # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
         # CEM for the next step.
-        self._prev_mean: torch.Tensor | None = None
+        self._prev_mean = None
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -514,6 +517,7 @@ class TDMPCPolicy(PreTrainedPolicy):
         update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
+# TODO(Steven): forward implementation missing
 class TDMPCTOLD(nn.Module):
     """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""

View File

@@ -144,7 +144,7 @@ class VQBeTConfig(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."

View File

@@ -70,6 +70,8 @@ class VQBeTPolicy(PreTrainedPolicy):
         self.vqbet = VQBeTModel(config)
+        self._queues = None
         self.reset()
     def get_optim_params(self) -> dict:
@@ -535,7 +537,7 @@ class VQBeTHead(nn.Module):
             cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
         )
         cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
-        NT, G, choices = cbet_probs.shape
+        NT, _G, choices = cbet_probs.shape
         sampled_centers = einops.rearrange(
             torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
             "(NT G) 1 -> NT G",
@@ -578,7 +580,7 @@ class VQBeTHead(nn.Module):
             "decoded_action": decoded_action,
         }
-    def loss_fn(self, pred, target, **kwargs):
+    def loss_fn(self, pred, target, **_kwargs):
         """
         for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
@@ -605,7 +607,7 @@ class VQBeTHead(nn.Module):
         # Figure out the loss for the actions.
         # First, we need to find the closest cluster center for each ground truth action.
         with torch.no_grad():
-            state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G
+            _state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G
         # Now we can compute the loss.
@@ -762,6 +764,7 @@ def _replace_submodules(
     return root_module
+# TODO(Steven): Missing implementation of forward, is it maybe vqvae_forward?
 class VqVae(nn.Module):
     def __init__(
         self,
@@ -876,13 +879,13 @@ class FocalLoss(nn.Module):
         self.gamma = gamma
         self.size_average = size_average
-    def forward(self, input, target):
-        if len(input.shape) == 3:
-            N, T, _ = input.shape
-            logpt = F.log_softmax(input, dim=-1)
+    def forward(self, forward_input, target):
+        if len(forward_input.shape) == 3:
+            N, T, _ = forward_input.shape
+            logpt = F.log_softmax(forward_input, dim=-1)
             logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
-        elif len(input.shape) == 2:
-            logpt = F.log_softmax(input, dim=-1)
+        elif len(forward_input.shape) == 2:
+            logpt = F.log_softmax(forward_input, dim=-1)
             logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
         pt = logpt.exp()

View File

@@ -34,63 +34,58 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
 # ruff: noqa: N806
-"""
-This file is part of a VQ-BeT that utilizes code from the following repositories:
-
-- Vector Quantize PyTorch code is licensed under the MIT License:
-    Original source: https://github.com/lucidrains/vector-quantize-pytorch
-
-- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
-    Original source: https://github.com/karpathy/nanoGPT
-
-We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
-"""
-
-"""
-This is a part for nanoGPT that utilizes code from the following repository:
-
-- Andrej Karpathy's nanoGPT implementation in PyTorch.
-    Original source: https://github.com/karpathy/nanoGPT
-
-- The nanoGPT code is licensed under the MIT License:
-
-MIT License
-
-Copyright (c) 2022 Andrej Karpathy
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-- We've made some changes to the original code to adapt it to our needs.
-
-Changed variable names:
-    - n_head -> gpt_n_head
-    - n_embd -> gpt_hidden_dim
-    - block_size -> gpt_block_size
-    - n_layer -> gpt_n_layer
-
-class GPT(nn.Module):
-    - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
-    - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
-    - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
-"""
+# This file is part of a VQ-BeT that utilizes code from the following repositories:
+#
+# - Vector Quantize PyTorch code is licensed under the MIT License:
+#   Original source: https://github.com/lucidrains/vector-quantize-pytorch
+#
+# - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
+#   Original source: https://github.com/karpathy/nanoGPT
+#
+# We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
+# This is a part for nanoGPT that utilizes code from the following repository:
+#
+# - Andrej Karpathy's nanoGPT implementation in PyTorch.
+#   Original source: https://github.com/karpathy/nanoGPT
+#
+# - The nanoGPT code is licensed under the MIT License:
+#
+# MIT License
+#
+# Copyright (c) 2022 Andrej Karpathy
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# - We've made some changes to the original code to adapt it to our needs.
+#
+# Changed variable names:
+#   - n_head -> gpt_n_head
+#   - n_embd -> gpt_hidden_dim
+#   - block_size -> gpt_block_size
+#   - n_layer -> gpt_n_layer
+#
+# class GPT(nn.Module):
+# - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
+# - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
+# - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
 class CausalSelfAttention(nn.Module):
@@ -200,9 +195,9 @@ class GPT(nn.Module):
         n_params = sum(p.numel() for p in self.parameters())
         print("number of parameters: {:.2f}M".format(n_params / 1e6))
-    def forward(self, input, targets=None):
-        device = input.device
-        b, t, d = input.size()
+    def forward(self, forward_input):
+        device = forward_input.device
+        _, t, _ = forward_input.size()
         assert t <= self.config.gpt_block_size, (
             f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
         )
@@ -211,7 +206,7 @@ class GPT(nn.Module):
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
         # forward the GPT model itself
-        tok_emb = self.transformer.wte(input)  # token embeddings of shape (b, t, gpt_hidden_dim)
+        tok_emb = self.transformer.wte(forward_input)  # token embeddings of shape (b, t, gpt_hidden_dim)
         pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, gpt_hidden_dim)
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
@@ -285,51 +280,48 @@ class GPT(nn.Module):
         return decay, no_decay
-"""
-This file is a part for Residual Vector Quantization that utilizes code from the following repository:
-
-- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
-    Original source: https://github.com/lucidrains/vector-quantize-pytorch
-
-- The vector-quantize-pytorch code is licensed under the MIT License:
-
-MIT License
-
-Copyright (c) 2020 Phil Wang
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-- We've made some changes to the original code to adapt it to our needs.
-
-class ResidualVQ(nn.Module):
-    - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
-      This enables the user to save an indicator whether the codebook is frozen or not.
-    - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
-      This is to make the function name more descriptive.
-
-class VectorQuantize(nn.Module):
-    - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
-      These parameters are not used in the code.
-    - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
-      This is to make the function name more descriptive.
-"""
+# This file is a part for Residual Vector Quantization that utilizes code from the following repository:
+#
+# - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
+#   Original source: https://github.com/lucidrains/vector-quantize-pytorch
+#
+# - The vector-quantize-pytorch code is licensed under the MIT License:
+#
+# MIT License
+#
+# Copyright (c) 2020 Phil Wang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# - We've made some changes to the original code to adapt it to our needs.
+#
+# class ResidualVQ(nn.Module):
+# - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
+#   This enables the user to save an indicator whether the codebook is frozen or not.
+# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
+#   This is to make the function name more descriptive.
+#
+# class VectorQuantize(nn.Module):
+# - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
+#   These parameters are not used in the code.
+# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
+#   This is to make the function name more descriptive.
 class ResidualVQ(nn.Module):
@@ -479,6 +471,9 @@ class ResidualVQ(nn.Module):
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
null_indices = None
null_loss = None
# sample a layer index at which to dropout further residual quantization # sample a layer index at which to dropout further residual quantization
# also prepare null indices and loss # also prepare null indices and loss
@@ -933,7 +928,7 @@ class VectorQuantize(nn.Module):
return quantize, embed_ind, loss return quantize, embed_ind, loss
def noop(*args, **kwargs): def noop(*_args, **_kwargs):
pass pass
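The two recurring fixes in this hunk are pre-initializing variables that are otherwise only assigned inside a conditional branch, and underscore-prefixing parameters that a function accepts but never uses. A minimal standalone sketch of both patterns, with illustrative names rather than the repository's:

def noop(*_args, **_kwargs):
    # Leading underscores mark the arguments as intentionally unused.
    pass

def quantize_with_dropout(layers, training=True):
    # Initialize up front so later code can rely on the names existing even
    # when the branch that assigns them is skipped (possibly-used-before-assignment).
    null_indices = None
    null_loss = None
    if training:
        null_indices = [0] * len(layers)
        null_loss = 0.0
    return null_indices, null_loss

print(quantize_with_dropout(["vq1", "vq2"], training=False))  # (None, None)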

View File

@@ -77,9 +77,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png" path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100) img.save(str(path), quality=100)
logging.info(f"Saved image: {path}") logging.info("Saved image: %s", path)
except Exception as e: except Exception as e:
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") logging.error("Failed to save image for camera %s frame %s: %s", serial_number, frame_index, e)
def save_images_from_cameras( def save_images_from_cameras(
@@ -447,7 +447,7 @@ class IntelRealSenseCamera:
num_tries += 1 num_tries += 1
time.sleep(1 / self.fps) time.sleep(1 / self.fps)
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()): if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
raise Exception( raise TimeoutError(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called." "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
) )
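Two of the recurring fixes meet in this file: f-string log messages become lazy %-style templates, which are only interpolated when the record is actually emitted, and the bare Exception is narrowed to TimeoutError. A small sketch under those assumptions; wait_for_thread is a hypothetical stand-in for the camera's async-read check:

import logging
import threading
import time

logging.basicConfig(level=logging.INFO)
# Lazy %-formatting defers string interpolation to the logging framework.
logging.info("Saved image: %s", "camera_0_frame_000001.png")

def wait_for_thread(thread: threading.Thread, timeout_s: float = 1.0) -> None:
    # Raise a specific exception type instead of the bare Exception base class.
    deadline = time.monotonic() + timeout_s
    while not thread.is_alive():
        if time.monotonic() > deadline:
            raise TimeoutError("Thread took too long to start; was thread.start() called?")
        time.sleep(0.01)

t = threading.Thread(target=time.sleep, args=(0.1,))
t.start()
wait_for_thread(t)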

View File

@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60 MAX_OPENCV_INDEX = 60
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: def find_cameras(max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
cameras = [] cameras = []
if platform.system() == "Linux": if platform.system() == "Linux":
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")

View File

@@ -169,7 +169,8 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
return steps return steps
def convert_to_bytes(value, bytes, mock=False): # TODO(Steven): Similar function in feetech.py, should be moved to a common place.
def convert_to_bytes(value, byte, mock=False):
if mock: if mock:
return value return value
@@ -177,16 +178,16 @@ def convert_to_bytes(value, bytes, mock=False):
# Note: No need to convert back into unsigned int, since this byte preprocessing # Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us. # already handles it for us.
if bytes == 1: if byte == 1:
data = [ data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
] ]
elif bytes == 2: elif byte == 2:
data = [ data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
] ]
elif bytes == 4: elif byte == 4:
data = [ data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
@@ -196,7 +197,7 @@ def convert_to_bytes(value, bytes, mock=False):
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but " f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead." f"{byte} is provided instead."
) )
return data return data
@@ -228,9 +229,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
all_addr = [] all_addr = []
all_bytes = [] all_bytes = []
for model in motor_models: for model in motor_models:
addr, bytes = model_ctrl_table[model][data_name] addr, byte = model_ctrl_table[model][data_name]
all_addr.append(addr) all_addr.append(addr)
all_bytes.append(bytes) all_bytes.append(byte)
if len(set(all_addr)) != 1: if len(set(all_addr)) != 1:
raise NotImplementedError( raise NotImplementedError(
@@ -576,6 +577,8 @@ class DynamixelMotorsBus:
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
low_factor = (start_pos - values[i]) / resolution low_factor = (start_pos - values[i]) / resolution
upp_factor = (end_pos - values[i]) / resolution upp_factor = (end_pos - values[i]) / resolution
else:
raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
if not in_range: if not in_range:
# Get first integer between the two bounds # Get first integer between the two bounds
@@ -596,10 +599,15 @@ class DynamixelMotorsBus:
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
else:
raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
logging.warning( logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " "Auto-correct calibration of motor '%s' by shifting value by {abs(factor)} full turns, "
f"from '{out_of_range_str}' to '{in_range_str}'." "from '%s' to '%s'.",
name,
out_of_range_str,
in_range_str,
) )
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
@@ -656,8 +664,8 @@ class DynamixelMotorsBus:
motor_ids = [motor_ids] motor_ids = [motor_ids]
assert_same_address(self.model_ctrl_table, self.motor_models, data_name) assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
for idx in motor_ids: for idx in motor_ids:
group.addParam(idx) group.addParam(idx)
@@ -674,7 +682,7 @@ class DynamixelMotorsBus:
values = [] values = []
for idx in motor_ids: for idx in motor_ids:
value = group.getData(idx, addr, bytes) value = group.getData(idx, addr, byte)
values.append(value) values.append(value)
if return_list: if return_list:
@@ -709,13 +717,13 @@ class DynamixelMotorsBus:
models.append(model) models.append(model)
assert_same_address(self.model_ctrl_table, models, data_name) assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name] addr, byte = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names) group_key = get_group_sync_key(data_name, motor_names)
if data_name not in self.group_readers: if data_name not in self.group_readers:
# create new group reader # create new group reader
self.group_readers[group_key] = dxl.GroupSyncRead( self.group_readers[group_key] = dxl.GroupSyncRead(
self.port_handler, self.packet_handler, addr, bytes self.port_handler, self.packet_handler, addr, byte
) )
for idx in motor_ids: for idx in motor_ids:
self.group_readers[group_key].addParam(idx) self.group_readers[group_key].addParam(idx)
@@ -733,7 +741,7 @@ class DynamixelMotorsBus:
values = [] values = []
for idx in motor_ids: for idx in motor_ids:
value = self.group_readers[group_key].getData(idx, addr, bytes) value = self.group_readers[group_key].getData(idx, addr, byte)
values.append(value) values.append(value)
values = np.array(values) values = np.array(values)
@@ -767,10 +775,10 @@ class DynamixelMotorsBus:
values = [values] values = [values]
assert_same_address(self.model_ctrl_table, motor_models, data_name) assert_same_address(self.model_ctrl_table, motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes) group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
for idx, value in zip(motor_ids, values, strict=True): for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes, self.mock) data = convert_to_bytes(value, byte, self.mock)
group.addParam(idx, data) group.addParam(idx, data)
for _ in range(num_retry): for _ in range(num_retry):
@@ -821,17 +829,17 @@ class DynamixelMotorsBus:
values = values.tolist() values = values.tolist()
assert_same_address(self.model_ctrl_table, models, data_name) assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name] addr, byte = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names) group_key = get_group_sync_key(data_name, motor_names)
init_group = data_name not in self.group_readers init_group = data_name not in self.group_readers
if init_group: if init_group:
self.group_writers[group_key] = dxl.GroupSyncWrite( self.group_writers[group_key] = dxl.GroupSyncWrite(
self.port_handler, self.packet_handler, addr, bytes self.port_handler, self.packet_handler, addr, byte
) )
for idx, value in zip(motor_ids, values, strict=True): for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes, self.mock) data = convert_to_bytes(value, byte, self.mock)
if init_group: if init_group:
self.group_writers[group_key].addParam(idx, data) self.group_writers[group_key].addParam(idx, data)
else: else:
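The bytes → byte rename that runs through this file (and recurs in the Feetech bus below) addresses pylint's redefined-builtin warning (W0622): a parameter named bytes shadows the built-in type inside the function body. A minimal sketch with illustrative pack_* helpers:

def pack_bad(value, bytes):  # triggers W0622: redefined-builtin
    # Inside this body `bytes` is an int, so the built-in constructor is shadowed:
    # bytes([value & 0xFF]) would raise TypeError: 'int' object is not callable.
    return [value & 0xFF] * bytes

def pack_good(value, byte):
    # With the parameter renamed, the built-in `bytes` type stays usable.
    return bytes([value & 0xFF] * byte)

print(pack_bad(0x1234, 2))   # [52, 52]
print(pack_good(0x1234, 2))  # b'44'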

View File

@@ -148,7 +148,7 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
return steps return steps
def convert_to_bytes(value, bytes, mock=False): def convert_to_bytes(value, byte, mock=False):
if mock: if mock:
return value return value
@@ -156,16 +156,16 @@ def convert_to_bytes(value, bytes, mock=False):
# Note: No need to convert back into unsigned int, since this byte preprocessing # Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us. # already handles it for us.
if bytes == 1: if byte == 1:
data = [ data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
] ]
elif bytes == 2: elif byte == 2:
data = [ data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
] ]
elif bytes == 4: elif byte == 4:
data = [ data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
@@ -175,7 +175,7 @@ def convert_to_bytes(value, bytes, mock=False):
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but " f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead." f"{byte} is provided instead."
) )
return data return data
@@ -207,9 +207,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
all_addr = [] all_addr = []
all_bytes = [] all_bytes = []
for model in motor_models: for model in motor_models:
addr, bytes = model_ctrl_table[model][data_name] addr, byte = model_ctrl_table[model][data_name]
all_addr.append(addr) all_addr.append(addr)
all_bytes.append(bytes) all_bytes.append(byte)
if len(set(all_addr)) != 1: if len(set(all_addr)) != 1:
raise NotImplementedError( raise NotImplementedError(
@@ -557,6 +557,8 @@ class FeetechMotorsBus:
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
low_factor = (start_pos - values[i]) / resolution low_factor = (start_pos - values[i]) / resolution
upp_factor = (end_pos - values[i]) / resolution upp_factor = (end_pos - values[i]) / resolution
else:
raise ValueError(f"Unknown calibration mode {calib_mode}")
if not in_range: if not in_range:
# Get first integer between the two bounds # Get first integer between the two bounds
@@ -577,10 +579,16 @@ class FeetechMotorsBus:
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
else:
raise ValueError(f"Unknown calibration mode {calib_mode}")
logging.warning( logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " "Auto-correct calibration of motor '%s' by shifting value by %s full turns, "
f"from '{out_of_range_str}' to '{in_range_str}'." "from '%s' to '%s'.",
name,
abs(factor),
out_of_range_str,
in_range_str,
) )
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
@@ -674,8 +682,8 @@ class FeetechMotorsBus:
motor_ids = [motor_ids] motor_ids = [motor_ids]
assert_same_address(self.model_ctrl_table, self.motor_models, data_name) assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
for idx in motor_ids: for idx in motor_ids:
group.addParam(idx) group.addParam(idx)
@@ -692,7 +700,7 @@ class FeetechMotorsBus:
values = [] values = []
for idx in motor_ids: for idx in motor_ids:
value = group.getData(idx, addr, bytes) value = group.getData(idx, addr, byte)
values.append(value) values.append(value)
if return_list: if return_list:
@@ -727,7 +735,7 @@ class FeetechMotorsBus:
models.append(model) models.append(model)
assert_same_address(self.model_ctrl_table, models, data_name) assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name] addr, byte = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names) group_key = get_group_sync_key(data_name, motor_names)
if data_name not in self.group_readers: if data_name not in self.group_readers:
@@ -737,7 +745,7 @@ class FeetechMotorsBus:
# create new group reader # create new group reader
self.group_readers[group_key] = scs.GroupSyncRead( self.group_readers[group_key] = scs.GroupSyncRead(
self.port_handler, self.packet_handler, addr, bytes self.port_handler, self.packet_handler, addr, byte
) )
for idx in motor_ids: for idx in motor_ids:
self.group_readers[group_key].addParam(idx) self.group_readers[group_key].addParam(idx)
@@ -755,7 +763,7 @@ class FeetechMotorsBus:
values = [] values = []
for idx in motor_ids: for idx in motor_ids:
value = self.group_readers[group_key].getData(idx, addr, bytes) value = self.group_readers[group_key].getData(idx, addr, byte)
values.append(value) values.append(value)
values = np.array(values) values = np.array(values)
@@ -792,10 +800,10 @@ class FeetechMotorsBus:
values = [values] values = [values]
assert_same_address(self.model_ctrl_table, motor_models, data_name) assert_same_address(self.model_ctrl_table, motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes) group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
for idx, value in zip(motor_ids, values, strict=True): for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes, self.mock) data = convert_to_bytes(value, byte, self.mock)
group.addParam(idx, data) group.addParam(idx, data)
for _ in range(num_retry): for _ in range(num_retry):
@@ -846,17 +854,17 @@ class FeetechMotorsBus:
values = values.tolist() values = values.tolist()
assert_same_address(self.model_ctrl_table, models, data_name) assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name] addr, byte = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names) group_key = get_group_sync_key(data_name, motor_names)
init_group = data_name not in self.group_readers init_group = data_name not in self.group_readers
if init_group: if init_group:
self.group_writers[group_key] = scs.GroupSyncWrite( self.group_writers[group_key] = scs.GroupSyncWrite(
self.port_handler, self.packet_handler, addr, bytes self.port_handler, self.packet_handler, addr, byte
) )
for idx, value in zip(motor_ids, values, strict=True): for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes, self.mock) data = convert_to_bytes(value, byte, self.mock)
if init_group: if init_group:
self.group_writers[group_key].addParam(idx, data) self.group_writers[group_key].addParam(idx, data)
else: else:

View File

@@ -95,6 +95,8 @@ def move_to_calibrate(
while_move_hook=None, while_move_hook=None,
): ):
initial_pos = arm.read("Present_Position", motor_name) initial_pos = arm.read("Present_Position", motor_name)
p_present_pos = None
n_present_pos = None
if positive_first: if positive_first:
p_present_pos = move_until_block( p_present_pos = move_until_block(
@@ -196,7 +198,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex") calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80) calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
def in_between_move_hook(): def in_between_move_hook_elbow():
nonlocal arm, calib nonlocal arm, calib
time.sleep(2) time.sleep(2)
ef_pos = arm.read("Present_Position", "elbow_flex") ef_pos = arm.read("Present_Position", "elbow_flex")
@@ -207,14 +209,14 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
print("Calibrate elbow_flex") print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate( calib["elbow_flex"] = move_to_calibrate(
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook_elbow
) )
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024) calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex") arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
time.sleep(1) time.sleep(1)
def in_between_move_hook(): def in_between_move_hook_shoulder():
nonlocal arm, calib nonlocal arm, calib
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex") arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
@@ -224,7 +226,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
"shoulder_lift", "shoulder_lift",
invert_drive_mode=True, invert_drive_mode=True,
positive_first=False, positive_first=False,
in_between_move_hook=in_between_move_hook, in_between_move_hook=in_between_move_hook_shoulder,
) )
# add an 30 steps as offset to align with body # add an 30 steps as offset to align with body
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50) calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)

View File

@@ -67,14 +67,14 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
return return
if calib_file.exists(): if calib_file.exists():
with open(calib_file) as f: with open(calib_file, encoding="utf-8") as f:
calibration = json.load(f) calibration = json.load(f)
print(f"[INFO] Loaded calibration from {calib_file}") print(f"[INFO] Loaded calibration from {calib_file}")
else: else:
print("[INFO] Calibration file not found. Running manual calibration...") print("[INFO] Calibration file not found. Running manual calibration...")
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower") calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
print(f"[INFO] Calibration complete. Saving to {calib_file}") print(f"[INFO] Calibration complete. Saving to {calib_file}")
with open(calib_file, "w") as f: with open(calib_file, "w", encoding="utf-8") as f:
json.dump(calibration, f) json.dump(calibration, f)
try: try:
motors_bus.set_calibration(calibration) motors_bus.set_calibration(calibration)
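Passing encoding="utf-8" explicitly resolves pylint's unspecified-encoding warning (W1514); without it, open() falls back to the locale's preferred encoding, which varies across platforms. A minimal sketch of the calibration round trip:

import json

calibration = {"shoulder_pan": 0}
with open("calibration.json", "w", encoding="utf-8") as f:
    json.dump(calibration, f)

with open("calibration.json", encoding="utf-8") as f:
    loaded = json.load(f)
print(loaded)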

View File

@@ -47,8 +47,10 @@ def ensure_safe_goal_position(
if not torch.allclose(goal_pos, safe_goal_pos): if not torch.allclose(goal_pos, safe_goal_pos):
logging.warning( logging.warning(
"Relative goal position magnitude had to be clamped to be safe.\n" "Relative goal position magnitude had to be clamped to be safe.\n"
f" requested relative goal position target: {diff}\n" " requested relative goal position target: %s\n"
f" clamped relative goal position target: {safe_diff}" " clamped relative goal position target: %s",
diff,
safe_diff,
) )
return safe_goal_pos return safe_goal_pos
@@ -245,6 +247,8 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
elif self.robot_type in ["so100", "moss", "lekiwi"]: elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.motors.feetech import TorqueMode from lerobot.common.robot_devices.motors.feetech import TorqueMode
else:
raise NotImplementedError(f"Robot type {self.robot_type} is not supported")
# We assume that at connection time, arms are in a rest position, and torque can # We assume that at connection time, arms are in a rest position, and torque can
# be safely disabled to run calibration and/or set robot preset configurations. # be safely disabled to run calibration and/or set robot preset configurations.
@@ -302,7 +306,7 @@ class ManipulatorRobot:
arm_calib_path = self.calibration_dir / f"{arm_id}.json" arm_calib_path = self.calibration_dir / f"{arm_id}.json"
if arm_calib_path.exists(): if arm_calib_path.exists():
with open(arm_calib_path) as f: with open(arm_calib_path, encoding="utf-8") as f:
calibration = json.load(f) calibration = json.load(f)
else: else:
# TODO(rcadene): display a warning in __init__ if calibration file not available # TODO(rcadene): display a warning in __init__ if calibration file not available
@@ -322,7 +326,7 @@ class ManipulatorRobot:
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True) arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f: with open(arm_calib_path, "w", encoding="utf-8") as f:
json.dump(calibration, f) json.dump(calibration, f)
return calibration return calibration

View File

@@ -262,14 +262,14 @@ class MobileManipulator:
arm_calib_path = self.calibration_dir / f"{arm_id}.json" arm_calib_path = self.calibration_dir / f"{arm_id}.json"
if arm_calib_path.exists(): if arm_calib_path.exists():
with open(arm_calib_path) as f: with open(arm_calib_path, encoding="utf-8") as f:
calibration = json.load(f) calibration = json.load(f)
else: else:
print(f"Missing calibration file '{arm_calib_path}'") print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True) arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f: with open(arm_calib_path, "w", encoding="utf-8") as f:
json.dump(calibration, f) json.dump(calibration, f)
return calibration return calibration
@@ -372,6 +372,7 @@ class MobileManipulator:
present_speed = self.last_present_speed present_speed = self.last_present_speed
# TODO(Steven): [WARN] Plenty of general exceptions
except Exception as e: except Exception as e:
print(f"[DEBUG] Error decoding video message: {e}") print(f"[DEBUG] Error decoding video message: {e}")
# If decode fails, fall back to old data # If decode fails, fall back to old data

View File

@@ -68,9 +68,9 @@ class TimeBenchmark(ContextDecorator):
Block took approximately 10.00 milliseconds Block took approximately 10.00 milliseconds
""" """
def __init__(self, print=False): def __init__(self, print_time=False):
self.local = threading.local() self.local = threading.local()
self.print_time = print self.print_time = print_time
def __enter__(self): def __enter__(self):
self.local.start_time = time.perf_counter() self.local.start_time = time.perf_counter()

View File

@@ -46,7 +46,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
else: else:
# For packages other than "torch", don't attempt the fallback and set as not available # For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False package_exists = False
logging.debug(f"Detected {pkg_name} version: {package_version}") logging.debug("Detected %s version: %s", {pkg_name}, package_version)
if return_version: if return_version:
return package_exists, package_version return package_exists, package_version
else: else:

View File

@@ -27,6 +27,8 @@ class AverageMeter:
def __init__(self, name: str, fmt: str = ":f"): def __init__(self, name: str, fmt: str = ":f"):
self.name = name self.name = name
self.fmt = fmt self.fmt = fmt
self.val = 0.0
self.avg = 0.0
self.reset() self.reset()
def reset(self) -> None: def reset(self) -> None:

View File

@@ -69,7 +69,7 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
case _: case _:
device = torch.device(try_device) device = torch.device(try_device)
if log: if log:
logging.warning(f"Using custom {try_device} device.") logging.warning("Using custom %s device.", try_device)
return device return device

View File

@@ -86,7 +86,7 @@ class WandBLogger:
resume="must" if cfg.resume else None, resume="must" if cfg.resume else None,
) )
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") logging.info("Track this run --> %s", colored(wandb.run.get_url(), "yellow", attrs=["bold"]))
self._wandb = wandb self._wandb = wandb
def log_policy(self, checkpoint_dir: Path): def log_policy(self, checkpoint_dir: Path):
@@ -108,7 +108,7 @@ class WandBLogger:
for k, v in d.items(): for k, v in d.items():
if not isinstance(v, (int, float, str)): if not isinstance(v, (int, float, str)):
logging.warning( logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' 'WandB logging of key "%s" was ignored as its type is not handled by this wrapper.', k
) )
continue continue
self._wandb.log({f"{mode}/{k}": v}, step=step) self._wandb.log({f"{mode}/{k}": v}, step=step)

View File

@@ -64,13 +64,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
self.pretrained_path = None self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device): if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device() auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.") logging.warning("Device '%s' is not available. Switching to '%s'.", self.device, auto_device)
self.device = auto_device.type self.device = auto_device.type
# Automatically deactivate AMP if necessary # Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device): if self.use_amp and not is_amp_available(self.device):
logging.warning( logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP." "Automatic Mixed Precision (amp) is not available on device '%s'. Deactivating AMP.",
self.device,
) )
self.use_amp = False self.use_amp = False
@@ -78,15 +79,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
def type(self) -> str: def type(self) -> str:
return self.get_choice_name(self.__class__) return self.get_choice_name(self.__class__)
@abc.abstractproperty @property
@abc.abstractmethod
def observation_delta_indices(self) -> list | None: def observation_delta_indices(self) -> list | None:
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty @property
@abc.abstractmethod
def action_delta_indices(self) -> list | None: def action_delta_indices(self) -> list | None:
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty @property
@abc.abstractmethod
def reward_delta_indices(self) -> list | None: def reward_delta_indices(self) -> list | None:
raise NotImplementedError raise NotImplementedError
@@ -128,7 +132,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
return None return None
def _save_pretrained(self, save_directory: Path) -> None: def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"): with open(save_directory / CONFIG_NAME, "w", encoding="utf-8") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4) draccus.dump(self, f, indent=4)
@classmethod @classmethod
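abc.abstractproperty has been deprecated since Python 3.3; stacking @property on top of @abc.abstractmethod is the supported replacement applied in these hunks. A short sketch with illustrative class names:

import abc

class BaseConfig(abc.ABC):
    @property
    @abc.abstractmethod
    def observation_delta_indices(self) -> list | None:
        raise NotImplementedError

class MyConfig(BaseConfig):
    @property
    def observation_delta_indices(self) -> list | None:
        return [0, 1]

print(MyConfig().observation_delta_indices)  # [0, 1]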

View File

@@ -123,7 +123,10 @@ class TrainPipelineConfig(HubMixin):
return draccus.encode(self) return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None: def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"): with (
open(save_directory / TRAIN_CONFIG_NAME, "w", encoding="utf-8") as f,
draccus.config_type("json"),
):
draccus.dump(self, f, indent=4) draccus.dump(self, f, indent=4)
@classmethod @classmethod
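The reworked _save_pretrained uses parenthesized context managers, available since Python 3.10, to split a long with statement across lines without backslashes. A sketch where nullcontext stands in for draccus.config_type("json"):

from contextlib import nullcontext

with (
    open("train_config.json", "w", encoding="utf-8") as f,
    nullcontext(),
):
    f.write("{}")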

View File

@@ -90,6 +90,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
print("Scanning all baudrates and motor indices") print("Scanning all baudrates and motor indices")
all_baudrates = set(series_baudrate_table.values()) all_baudrates = set(series_baudrate_table.values())
motor_index = -1 # Set the motor index to an out-of-range value. motor_index = -1 # Set the motor index to an out-of-range value.
baudrate = None
for baudrate in all_baudrates: for baudrate in all_baudrates:
motor_bus.set_bus_baudrate(baudrate) motor_bus.set_bus_baudrate(baudrate)

View File

@@ -81,6 +81,7 @@ This might require a sudo permission to allow your terminal to monitor keyboard
**NOTE**: You can resume/continue data recording by running the same data recording command twice. **NOTE**: You can resume/continue data recording by running the same data recording command twice.
""" """
# TODO(Steven): This script should be updated to use the new robot API and the new dataset API.
import argparse import argparse
import importlib import importlib
import logging import logging

View File

@@ -59,7 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A" torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" cuda_version = torch.version.cuda if HAS_TORCH and torch.version.cuda is not None else "N/A"
# TODO(aliberts): refactor into an actual command `lerobot env` # TODO(aliberts): refactor into an actual command `lerobot env`
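This change (like the later dataset.meta.version fix) swaps a protected, underscore-prefixed attribute for its public counterpart, which silences pylint's protected-access warning and keeps the code off private APIs that may change without notice. A one-line sketch:

import torch

# Public accessor instead of the protected torch._C._cuda_getCompiledVersion().
cuda_version = torch.version.cuda if torch.version.cuda is not None else "N/A"
print("CUDA version torch was built with:", cuda_version)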

View File

@@ -259,6 +259,10 @@ def eval_policy(
threads = [] # for video saving threads threads = [] # for video saving threads
n_episodes_rendered = 0 # for saving the correct number of videos n_episodes_rendered = 0 # for saving the correct number of videos
video_paths: list[str] = [] # max_episodes_rendered > 0
ep_frames: list[np.ndarray] = [] # max_episodes_rendered > 0
episode_data: dict | None = None # return_episode_data == True
# Callback for visualization. # Callback for visualization.
def render_frame(env: gym.vector.VectorEnv): def render_frame(env: gym.vector.VectorEnv):
# noqa: B023 # noqa: B023
@@ -271,19 +275,11 @@ def eval_policy(
# Here we must render all frames and discard any we don't need. # Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
if max_episodes_rendered > 0:
video_paths: list[str] = []
if return_episode_data:
episode_data: dict | None = None
# we dont want progress bar when we use slurm, since it clutters the logs # we dont want progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm()) progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
for batch_ix in progbar: for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step. # step.
if max_episodes_rendered > 0:
ep_frames: list[np.ndarray] = []
if start_seed is None: if start_seed is None:
seeds = None seeds = None
@@ -320,13 +316,19 @@ def eval_policy(
else: else:
all_seeds.append(None) all_seeds.append(None)
# FIXME: episode_data is either None or it doesn't exist
if return_episode_data: if return_episode_data:
if episode_data is None:
start_data_index = 0
elif isinstance(episode_data, dict):
start_data_index = episode_data["index"][-1].item() + 1
else:
start_data_index = 0
this_episode_data = _compile_episode_data( this_episode_data = _compile_episode_data(
rollout_data, rollout_data,
done_indices, done_indices,
start_episode_index=batch_ix * env.num_envs, start_episode_index=batch_ix * env.num_envs,
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)), start_data_index=start_data_index,
fps=env.unwrapped.metadata["render_fps"], fps=env.unwrapped.metadata["render_fps"],
) )
if episode_data is None: if episode_data is None:
@@ -453,6 +455,7 @@ def _compile_episode_data(
return data_dict return data_dict
# TODO(Steven): [WARN] Redefining built-in 'eval'
@parser.wrap() @parser.wrap()
def eval_main(cfg: EvalPipelineConfig): def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg))) logging.info(pformat(asdict(cfg)))
@@ -489,7 +492,7 @@ def eval_main(cfg: EvalPipelineConfig):
print(info["aggregated"]) print(info["aggregated"])
# Save info # Save info
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f: with open(Path(cfg.output_dir) / "eval_info.json", "w", encoding="utf-8") as f:
json.dump(info, f, indent=2) json.dump(info, f, indent=2)
env.close() env.close()
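Hoisting the video_paths / ep_frames / episode_data initializations above the loop and unfolding the nested conditional into explicit branches addresses pylint's possibly-used-before-assignment warning in eval_policy. A reduced sketch of the episode_data bookkeeping; the dict literal is a hypothetical stand-in for _compile_episode_data:

episode_data: dict | None = None  # declared once, before the loop

for batch_ix in range(3):
    if episode_data is None:
        start_data_index = 0
    else:
        start_data_index = episode_data["index"][-1] + 1
    # Stand-in for _compile_episode_data(rollout_data, ...):
    episode_data = {"index": list(range(start_data_index, start_data_index + 10))}

print(episode_data["index"][-1])  # 29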

View File

@@ -53,6 +53,7 @@ import torch
from huggingface_hub import HfApi from huggingface_hub import HfApi
from safetensors.torch import save_file from safetensors.torch import save_file
# TODO(Steven): #711 Broke this
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
@@ -89,7 +90,7 @@ def save_meta_data(
# save info # save info
info_path = meta_data_dir / "info.json" info_path = meta_data_dir / "info.json"
with open(str(info_path), "w") as f: with open(str(info_path), "w", encoding="utf-8") as f:
json.dump(info, f, indent=4) json.dump(info, f, indent=4)
# save stats # save stats
@@ -120,11 +121,11 @@ def push_dataset_card_to_hub(
repo_id: str, repo_id: str,
revision: str | None, revision: str | None,
tags: list | None = None, tags: list | None = None,
license: str = "apache-2.0", dataset_license: str = "apache-2.0",
**card_kwargs, **card_kwargs,
): ):
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub.""" """Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs) card = create_lerobot_dataset_card(tags=tags, license=dataset_license, **card_kwargs)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision) card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
@@ -213,6 +214,7 @@ def push_dataset_to_hub(
encoding, encoding,
) )
# TODO(Steven): This doesn't seem to exist, maybe it was removed/changed recently?
lerobot_dataset = LeRobotDataset.from_preloaded( lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id, repo_id=repo_id,
hf_dataset=hf_dataset, hf_dataset=hf_dataset,

View File

@@ -155,12 +155,14 @@ def train(cfg: TrainPipelineConfig):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None: if cfg.env is not None:
logging.info(f"{cfg.env.task=}") logging.info("cfg.env.task=%s", cfg.env.task)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") logging.info("cfg.steps=%s (%s)", cfg.steps, format_big_number(cfg.steps))
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") logging.info("dataset.num_frames=%s (%s)", dataset.num_frames, format_big_number(dataset.num_frames))
logging.info(f"{dataset.num_episodes=}") logging.info("dataset.num_episodes=%s", dataset.num_episodes)
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") "num_learnable_params=%s (%s)", num_learnable_params, format_big_number(num_learnable_params)
)
logging.info("num_total_params=%s (%s)", num_total_params, format_big_number(num_total_params))
# create dataloader for offline training # create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"): if hasattr(cfg.policy, "drop_n_last_frames"):
@@ -238,7 +240,7 @@ def train(cfg: TrainPipelineConfig):
train_tracker.reset_averages() train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step: if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}") logging.info("Checkpoint policy after step %s", step)
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler) save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
update_last_checkpoint(checkpoint_dir) update_last_checkpoint(checkpoint_dir)
@@ -247,7 +249,7 @@ def train(cfg: TrainPipelineConfig):
if cfg.env and is_eval_step: if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps) step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}") logging.info("Eval policy at step %s", step)
with ( with (
torch.no_grad(), torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),

View File

@@ -150,7 +150,7 @@ def run_server(
400, 400,
) )
dataset_version = ( dataset_version = (
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version str(dataset.meta.version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
) )
match = re.search(r"v(\d+)\.", dataset_version) match = re.search(r"v(\d+)\.", dataset_version)
if match: if match:
@@ -358,7 +358,7 @@ def visualize_dataset_html(
if force_override: if force_override:
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
else: else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'") logging.info("Output directory already exists. Loading from it: '%s'", {output_dir})
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -52,16 +52,16 @@ def get_task_index(task_dicts: dict, task: str) -> int:
return task_to_task_index[task] return task_to_task_index[task]
@pytest.fixture(scope="session") @pytest.fixture(name="img_tensor_factory", scope="session")
def img_tensor_factory(): def fixture_img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor: def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype) return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor return _create_img_tensor
@pytest.fixture(scope="session") @pytest.fixture(name="img_array_factory", scope="session")
def img_array_factory(): def fixture_img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray: def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger): if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range # Int array in [0, 255] range
@@ -76,8 +76,8 @@ def img_array_factory():
return _create_img_array return _create_img_array
@pytest.fixture(scope="session") @pytest.fixture(name="img_factory", scope="session")
def img_factory(img_array_factory): def fixture_img_factory(img_array_factory):
def _create_img(height=100, width=100) -> PIL.Image.Image: def _create_img(height=100, width=100) -> PIL.Image.Image:
img_array = img_array_factory(height=height, width=width) img_array = img_array_factory(height=height, width=width)
return PIL.Image.fromarray(img_array) return PIL.Image.fromarray(img_array)
@@ -85,13 +85,17 @@ def img_factory(img_array_factory):
return _create_img return _create_img
@pytest.fixture(scope="session") @pytest.fixture(name="features_factory", scope="session")
def features_factory(): def fixture_features_factory():
def _create_features( def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict | None = None,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict | None = None,
use_videos: bool = True, use_videos: bool = True,
) -> dict: ) -> dict:
if motor_features is None:
motor_features = DUMMY_MOTOR_FEATURES
if camera_features is None:
camera_features = DUMMY_CAMERA_FEATURES
if use_videos: if use_videos:
camera_ft = { camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items() key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
@@ -107,8 +111,8 @@ def features_factory():
return _create_features return _create_features
@pytest.fixture(scope="session") @pytest.fixture(name="info_factory", scope="session")
def info_factory(features_factory): def fixture_info_factory(features_factory):
def _create_info( def _create_info(
codebase_version: str = CODEBASE_VERSION, codebase_version: str = CODEBASE_VERSION,
fps: int = DEFAULT_FPS, fps: int = DEFAULT_FPS,
@@ -121,10 +125,14 @@ def info_factory(features_factory):
chunks_size: int = DEFAULT_CHUNK_SIZE, chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH, data_path: str = DEFAULT_PARQUET_PATH,
video_path: str = DEFAULT_VIDEO_PATH, video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict | None = None,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict | None = None,
use_videos: bool = True, use_videos: bool = True,
) -> dict: ) -> dict:
if motor_features is None:
motor_features = DUMMY_MOTOR_FEATURES
if camera_features is None:
camera_features = DUMMY_CAMERA_FEATURES
features = features_factory(motor_features, camera_features, use_videos) features = features_factory(motor_features, camera_features, use_videos)
return { return {
"codebase_version": codebase_version, "codebase_version": codebase_version,
@@ -145,8 +153,8 @@ def info_factory(features_factory):
return _create_info return _create_info
@pytest.fixture(scope="session") @pytest.fixture(name="stats_factory", scope="session")
def stats_factory(): def fixture_stats_factory():
def _create_stats( def _create_stats(
features: dict[str] | None = None, features: dict[str] | None = None,
) -> dict: ) -> dict:
@@ -175,8 +183,8 @@ def stats_factory():
return _create_stats return _create_stats
@pytest.fixture(scope="session") @pytest.fixture(name="episodes_stats_factory", scope="session")
def episodes_stats_factory(stats_factory): def fixture_episodes_stats_factory(stats_factory):
def _create_episodes_stats( def _create_episodes_stats(
features: dict[str], features: dict[str],
total_episodes: int = 3, total_episodes: int = 3,
@@ -192,8 +200,8 @@ def episodes_stats_factory(stats_factory):
return _create_episodes_stats return _create_episodes_stats
@pytest.fixture(scope="session") @pytest.fixture(name="tasks_factory", scope="session")
def tasks_factory(): def fixture_tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int: def _create_tasks(total_tasks: int = 3) -> int:
tasks = {} tasks = {}
for task_index in range(total_tasks): for task_index in range(total_tasks):
@@ -204,8 +212,8 @@ def tasks_factory():
return _create_tasks return _create_tasks
@pytest.fixture(scope="session") @pytest.fixture(name="episodes_factory", scope="session")
def episodes_factory(tasks_factory): def fixture_episodes_factory(tasks_factory):
def _create_episodes( def _create_episodes(
total_episodes: int = 3, total_episodes: int = 3,
total_frames: int = 400, total_frames: int = 400,
@@ -252,8 +260,8 @@ def episodes_factory(tasks_factory):
return _create_episodes return _create_episodes
@pytest.fixture(scope="session") @pytest.fixture(name="hf_dataset_factory", scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory): def fixture_hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset( def _create_hf_dataset(
features: dict | None = None, features: dict | None = None,
tasks: list[dict] | None = None, tasks: list[dict] | None = None,
@@ -310,8 +318,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
return _create_hf_dataset return _create_hf_dataset
@pytest.fixture(scope="session") @pytest.fixture(name="lerobot_dataset_metadata_factory", scope="session")
def lerobot_dataset_metadata_factory( def fixture_lerobot_dataset_metadata_factory(
info_factory, info_factory,
stats_factory, stats_factory,
episodes_stats_factory, episodes_stats_factory,
@@ -364,8 +372,8 @@ def lerobot_dataset_metadata_factory(
return _create_lerobot_dataset_metadata return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session") @pytest.fixture(name="lerobot_dataset_factory", scope="session")
def lerobot_dataset_factory( def fixture_lerobot_dataset_factory(
info_factory, info_factory,
stats_factory, stats_factory,
episodes_stats_factory, episodes_stats_factory,
@@ -443,6 +451,6 @@ def lerobot_dataset_factory(
return _create_lerobot_dataset return _create_lerobot_dataset
@pytest.fixture(scope="session") @pytest.fixture(name="empty_lerobot_dataset_factory", scope="session")
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory: def fixture_empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS) return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
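Several factories here also replace mutable default arguments (motor_features: dict = DUMMY_MOTOR_FEATURES) with a None sentinel resolved inside the function, fixing pylint's dangerous-default-value warning (W0102): a mutable default is built once at definition time and shared by every call. A minimal before/after sketch:

def create_features_bad(motor_features={}):  # W0102: dangerous-default-value
    motor_features.setdefault("count", 0)
    motor_features["count"] += 1
    return motor_features

def create_features_good(motor_features=None):
    if motor_features is None:
        motor_features = {}
    motor_features.setdefault("count", 0)
    motor_features["count"] += 1
    return motor_features

print(create_features_bad(), create_features_bad())    # both print {'count': 2}: the default dict is shared
print(create_features_good(), create_features_good())  # {'count': 1} {'count': 1}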

View File

@@ -31,12 +31,12 @@ from lerobot.common.datasets.utils import (
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def info_path(info_factory): def info_path(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path: def _create_info_json_file(input_dir: Path, info: dict | None = None) -> Path:
if not info: if not info:
info = info_factory() info = info_factory()
fpath = dir / INFO_PATH fpath = input_dir / INFO_PATH
fpath.parent.mkdir(parents=True, exist_ok=True) fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f: with open(fpath, "w", encoding="utf-8") as f:
json.dump(info, f, indent=4, ensure_ascii=False) json.dump(info, f, indent=4, ensure_ascii=False)
return fpath return fpath
@@ -45,12 +45,12 @@ def info_path(info_factory):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def stats_path(stats_factory): def stats_path(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path: def _create_stats_json_file(input_dir: Path, stats: dict | None = None) -> Path:
if not stats: if not stats:
stats = stats_factory() stats = stats_factory()
fpath = dir / STATS_PATH fpath = input_dir / STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True) fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f: with open(fpath, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=4, ensure_ascii=False) json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath return fpath
@@ -59,10 +59,10 @@ def stats_path(stats_factory):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory): def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path: def _create_episodes_stats_jsonl_file(input_dir: Path, episodes_stats: list[dict] | None = None) -> Path:
if not episodes_stats: if not episodes_stats:
episodes_stats = episodes_stats_factory() episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH fpath = input_dir / EPISODES_STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True) fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer: with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes_stats.values()) writer.write_all(episodes_stats.values())
@@ -73,10 +73,10 @@ def episodes_stats_path(episodes_stats_factory):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tasks_path(tasks_factory): def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path: def _create_tasks_jsonl_file(input_dir: Path, tasks: list | None = None) -> Path:
if not tasks: if not tasks:
tasks = tasks_factory() tasks = tasks_factory()
fpath = dir / TASKS_PATH fpath = input_dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks.values())
@@ -87,10 +87,10 @@ def tasks_path(tasks_factory):
@pytest.fixture(scope="session")
def episode_path(episodes_factory):
- def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
+ def _create_episodes_jsonl_file(input_dir: Path, episodes: list | None = None) -> Path:
if not episodes:
episodes = episodes_factory()
- fpath = dir / EPISODES_PATH
+ fpath = input_dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes.values())
@@ -102,7 +102,7 @@ def episode_path(episodes_factory):
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
- dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
+ input_dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
@@ -112,7 +112,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
data_path = info["data_path"]
chunks_size = info["chunks_size"]
ep_chunk = ep_idx // chunks_size
- fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+ fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
@@ -125,7 +125,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet(
- dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
+ input_dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
@@ -137,11 +137,11 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
total_episodes = info["total_episodes"]
for ep_idx in range(total_episodes):
ep_chunk = ep_idx // chunks_size
- fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+ fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
- return dir / "data"
+ return input_dir / "data"
return _create_multi_episode_parquet

View File

@@ -81,12 +81,12 @@ def mock_snapshot_download_factory(
return None
def _mock_snapshot_download(
- repo_id: str,
+ _repo_id: str,
+ *_args,
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
- *args,
- **kwargs,
+ **_kwargs,
) -> str:
if not local_dir:
local_dir = LEROBOT_TEST_DIR
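
The rewritten mock above keeps the broad `snapshot_download` call signature while marking everything it ignores: unused parameters get a leading underscore and any extra positional or keyword arguments are absorbed by `*_args` / `**_kwargs`. A minimal sketch of the same pattern, with hypothetical names (`fake_snapshot_download` and `FAKE_HUB_DIR` are illustrations, not code from the repo):

from pathlib import Path

FAKE_HUB_DIR = Path("/tmp/fake_hub")  # assumed stand-in for a test directory

def fake_snapshot_download(
    _repo_id: str,
    *_args,
    local_dir: str | Path | None = None,
    **_kwargs,
) -> str:
    # Only `local_dir` matters to the tests; everything else is accepted and ignored.
    target = Path(local_dir) if local_dir else FAKE_HUB_DIR
    target.mkdir(parents=True, exist_ok=True)
    return str(target)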

View File

@@ -18,13 +18,13 @@ from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
- @pytest.fixture
- def model_params():
+ @pytest.fixture(name="model_params")
+ def fixture_model_params():
return [torch.nn.Parameter(torch.randn(10, 10))]
- @pytest.fixture
- def optimizer(model_params):
+ @pytest.fixture(name="optimizer")
+ def fixture_optimizer(model_params):
optimizer = AdamConfig().build(model_params)
# Dummy step to populate state
loss = sum(param.sum() for param in model_params)
@@ -33,7 +33,7 @@ def optimizer(model_params):
return optimizer
- @pytest.fixture
- def scheduler(optimizer):
+ @pytest.fixture(name="scheduler")
+ def fixture_scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
return config.build(optimizer, num_training_steps=100)
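
The decorator change above is pytest's way of renaming a fixture without breaking the tests that request it: `@pytest.fixture(name="model_params")` keeps the fixture's public name while the function itself gets a `fixture_` prefix, so test arguments no longer shadow a module-level function (pylint's redefined-outer-name warning). A minimal sketch under hypothetical names, with a plain list standing in for the real torch parameters:

import pytest

@pytest.fixture(name="model_params")
def fixture_model_params():
    # The fixture is still requested as `model_params` in tests.
    return [1.0, 2.0, 3.0]

def test_param_count(model_params):
    assert len(model_params) == 3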

View File

@@ -22,6 +22,8 @@ from lerobot.common.datasets.utils import create_lerobot_dataset_card
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
+ # TODO(Steven): Base class CardDate should have 'tags' as a member if we want RepoCard to hold a reference to this abstraction
+ # card.data gives a CardDate type, implementations of this class do have 'tags' but the base class doesn't
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [

View File

@@ -57,7 +57,7 @@ def rotate(color_image, rotation):
class VideoCapture:
- def __init__(self, *args, **kwargs):
+ def __init__(self, *_args, **_kwargs):
self._mock_dict = {
CAP_PROP_FPS: 30,
CAP_PROP_FRAME_WIDTH: 640,

View File

@@ -24,10 +24,9 @@ DEFAULT_BAUDRATE = 9_600
COMM_SUCCESS = 0 # tx or rx packet communication success
- def convert_to_bytes(value, bytes):
+ def convert_to_bytes(value, _byte):
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
# `convert_bytes_to_value`
- del bytes # unused
return value
@@ -74,7 +73,7 @@ class PacketHandler:
class GroupSyncRead:
- def __init__(self, port_handler, packet_handler, address, bytes):
+ def __init__(self, _port_handler, packet_handler, _address, _byte):
self.packet_handler = packet_handler
def addParam(self, motor_index): # noqa: N802
@@ -85,12 +84,12 @@ class GroupSyncRead:
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS
- def getData(self, index, address, bytes): # noqa: N802
+ def getData(self, index, address, _byte): # noqa: N802
return self.packet_handler.data[index][address]
class GroupSyncWrite:
- def __init__(self, port_handler, packet_handler, address, bytes):
+ def __init__(self, _port_handler, packet_handler, address, _byte):
self.packet_handler = packet_handler
self.address = address
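
Across these motor-bus mocks the fix is the same: parameters that only exist to mirror the real SDK's call signature are renamed with a leading underscore, so linters treat them as intentionally unused. A stripped-down sketch with hypothetical class names (not the actual mock classes):

class FakePacketHandler:
    def __init__(self):
        # address -> value table, keyed by motor index
        self.data = {1: {10: 42}}

class FakeGroupSyncRead:
    def __init__(self, _port_handler, packet_handler, _address, _byte):
        # Only the packet handler is actually used by the mock.
        self.packet_handler = packet_handler

    def get_data(self, index, address, _byte):
        return self.packet_handler.data[index][address]

reader = FakeGroupSyncRead(None, FakePacketHandler(), 0, 4)
assert reader.get_data(1, 10, 4) == 42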

View File

@@ -27,6 +27,13 @@ class format(enum.Enum): # noqa: N801
class config: # noqa: N801
+ device_enabled = None
+ stream_type = None
+ width = None
+ height = None
+ color_format = None
+ fps = None
def enable_device(self, device_id: str):
self.device_enabled = device_id
@@ -125,8 +132,7 @@ class RSDevice:
def __init__(self):
pass
- def get_info(self, camera_info) -> str:
- del camera_info # unused
+ def get_info(self, _camera_info) -> str:
# return fake serial number
return "123456789"
@@ -145,4 +151,3 @@ class camera_info: # noqa: N801
def __init__(self, serial_number):
del serial_number
- pass

View File

@@ -24,10 +24,10 @@ DEFAULT_BAUDRATE = 1_000_000
COMM_SUCCESS = 0 # tx or rx packet communication success
- def convert_to_bytes(value, bytes):
+ def convert_to_bytes(value, byte):
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
# `convert_bytes_to_value`
- del bytes # unused
+ del byte # unused
return value
@@ -85,7 +85,7 @@ class PacketHandler:
class GroupSyncRead:
- def __init__(self, port_handler, packet_handler, address, bytes):
+ def __init__(self, _port_handler, packet_handler, _address, _byte):
self.packet_handler = packet_handler
def addParam(self, motor_index): # noqa: N802
@@ -96,12 +96,12 @@ class GroupSyncRead:
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS
- def getData(self, index, address, bytes): # noqa: N802
+ def getData(self, index, address, _byte): # noqa: N802
return self.packet_handler.data[index][address]
class GroupSyncWrite:
- def __init__(self, port_handler, packet_handler, address, bytes):
+ def __init__(self, _port_handler, packet_handler, address, _byte):
self.packet_handler = packet_handler
self.address = address

View File

@@ -81,11 +81,11 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
if __name__ == "__main__":
- for dataset in [
+ for available_dataset in [
"lerobot/pusht",
"lerobot/aloha_sim_insertion_human",
"lerobot/xarm_lift_medium",
"lerobot/nyu_franka_play_dataset",
"lerobot/cmu_stretch",
]:
- save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
+ save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=available_dataset)

View File

@@ -51,7 +51,7 @@ def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
}
frames = {"original_frame": original_frame}
- for tf_type, tf_name, min_max_values in transforms.items():
+ for tf_type, tf_name, min_max_values in transforms:
for min_max in min_max_values:
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
tf = make_transform_from_config(tf_cfg)

View File

@@ -150,6 +150,7 @@ def test_camera(request, camera_type, mock):
else:
import cv2
+ manual_rot_img: np.ndarray = None
if rotation is None:
manual_rot_img = ori_color_image
assert camera.rotation is None
@@ -197,10 +198,14 @@ def test_camera(request, camera_type, mock):
@require_camera
def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
# TODO(rcadene): refactor
+ save_images_from_cameras = None
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
+ else:
+ raise ValueError(f"Unsupported camera type: {camera_type}")
# Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
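
The two additions above (`save_images_from_cameras = None` plus the trailing `else: raise ValueError(...)`) remove the possibility of calling a name that was never bound when an unexpected `camera_type` slips through. A small sketch of the same guard, using a hypothetical backend resolver instead of the camera imports:

def resolve_camera_backend(camera_type: str) -> str:
    backend = None  # pre-bind so the name always exists
    if camera_type == "opencv":
        backend = "cv2"
    elif camera_type == "intelrealsense":
        backend = "pyrealsense2"
    else:
        raise ValueError(f"Unsupported camera type: {camera_type}")
    return backend

assert resolve_camera_backend("opencv") == "cv2"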

View File

@@ -30,12 +30,12 @@ from lerobot.common.datasets.compute_stats import (
)
- def mock_load_image_as_numpy(path, dtype, channel_first):
+ def mock_load_image_as_numpy(_path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
- @pytest.fixture
- def sample_array():
+ @pytest.fixture(name="sample_array")
+ def fixture_sample_array():
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -62,7 +62,7 @@ def test_sample_indices():
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
- def test_sample_images(mock_load):
+ def test_sample_images(_mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths)
assert isinstance(images, np.ndarray)

View File

@@ -48,8 +48,8 @@ from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import require_x86_64_kernel
- @pytest.fixture
- def image_dataset(tmp_path, empty_lerobot_dataset_factory):
+ @pytest.fixture(name="image_dataset")
+ def fixture_image_dataset(tmp_path, empty_lerobot_dataset_factory):
features = {
"image": {
"dtype": "image",
@@ -374,7 +374,7 @@ def test_factory(env_name, repo_id, policy_name):
if required:
assert key in item, f"{key}"
else:
- logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
+ logging.warning('Missing key in dataset: "%s" not in %s.', key, dataset)
continue
if delta_timestamps is not None and key in delta_timestamps:
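
The logging change above swaps an f-string for lazy `%`-style arguments: the logger only interpolates the message if the record is actually emitted, which is what pylint's logging-fstring-interpolation warning asks for. A short sketch with hypothetical values:

import logging

logging.basicConfig(level=logging.WARNING)
key, dataset = "observation.state", "dummy_dataset"  # hypothetical values

# Formatting is deferred until the logger decides to emit the record.
logging.warning('Missing key in dataset: "%s" not in %s.', key, dataset)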

View File

@@ -42,7 +42,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
- ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+ ep_table = table.filter(
+ pc.equal(table["episode_index"], ep_idx)
+ )  # TODO(Steven): What is this check supposed to do?
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lengths = list(accumulate(episode_lengths))
@@ -52,8 +54,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
}
- @pytest.fixture(scope="module")
- def synced_timestamps_factory(hf_dataset_factory):
+ @pytest.fixture(name="synced_timestamps_factory", scope="module")
+ def fixture_synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
@@ -64,8 +66,8 @@ def synced_timestamps_factory(hf_dataset_factory):
return _create_synced_timestamps
- @pytest.fixture(scope="module")
- def unsynced_timestamps_factory(synced_timestamps_factory):
+ @pytest.fixture(name="unsynced_timestamps_factory", scope="module")
+ def fixture_unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -76,8 +78,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
return _create_unsynced_timestamps
- @pytest.fixture(scope="module")
- def slightly_off_timestamps_factory(synced_timestamps_factory):
+ @pytest.fixture(name="slightly_off_timestamps_factory", scope="module")
+ def fixture_slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -88,22 +90,26 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
return _create_slightly_off_timestamps
- @pytest.fixture(scope="module")
- def valid_delta_timestamps_factory():
+ @pytest.fixture(name="valid_delta_timestamps_factory", scope="module")
+ def fixture_valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(
- fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
+ fps: int = 30, keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)
) -> dict:
+ if keys is None:
+ keys = DUMMY_MOTOR_FEATURES
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
return delta_timestamps
return _create_valid_delta_timestamps
- @pytest.fixture(scope="module")
- def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
+ @pytest.fixture(name="invalid_delta_timestamps_factory", scope="module")
+ def fixture_invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_invalid_delta_timestamps(
- fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
+ fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
) -> dict:
+ if keys is None:
+ keys = DUMMY_MOTOR_FEATURES
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just outside tolerance
for key in keys:
@@ -113,11 +119,13 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
return _create_invalid_delta_timestamps
- @pytest.fixture(scope="module")
- def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
+ @pytest.fixture(name="slightly_off_delta_timestamps_factory", scope="module")
+ def fixture_slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_slightly_off_delta_timestamps(
- fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
+ fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
) -> dict:
+ if keys is None:
+ keys = DUMMY_MOTOR_FEATURES
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just inside tolerance
for key in delta_timestamps:
@@ -128,9 +136,11 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
return _create_slightly_off_delta_timestamps
- @pytest.fixture(scope="module")
- def delta_indices_factory():
+ @pytest.fixture(name="delta_indices_factory", scope="module")
+ def fixture_delta_indices_factory():
- def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
+ def _delta_indices(keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
+ if keys is None:
+ keys = DUMMY_MOTOR_FEATURES
return {key: list(range(*min_max_range)) for key in keys}
return _delta_indices
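
The recurring change in these factories replaces a module-level list used as a default argument with a `None` sentinel resolved inside the function body: default values are evaluated once at definition time, so every call would otherwise share the same list object. A minimal sketch of the pattern (`DUMMY_KEYS` and `make_delta_timestamps` are stand-ins, not the real names):

DUMMY_KEYS = ["action", "observation.state"]

def make_delta_timestamps(fps: int = 30, keys: list | None = None) -> dict:
    if keys is None:
        keys = DUMMY_KEYS  # resolved per call, never shared between calls
    return {key: [i * (1 / fps) for i in range(-2, 2)] for key in keys}

assert set(make_delta_timestamps()) == set(DUMMY_KEYS)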

View File

@@ -38,7 +38,7 @@ def _run_script(path):
def _read_file(path):
- with open(path) as file:
+ with open(path, encoding="utf-8") as file:
return file.read()
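
Passing `encoding="utf-8"` explicitly, as in the change above, makes file I/O independent of the platform's locale default (pylint's unspecified-encoding warning, W1514). A tiny sketch with a hypothetical file path:

from pathlib import Path

path = Path("example.txt")  # hypothetical file
path.write_text("hello", encoding="utf-8")
with open(path, encoding="utf-8") as file:
    assert file.read() == "hello"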

View File

@@ -37,8 +37,8 @@ from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.utils import require_x86_64_kernel
- @pytest.fixture
- def color_jitters():
+ @pytest.fixture(name="color_jitters")
+ def fixture_color_jitters():
return [
v2.ColorJitter(brightness=0.5),
v2.ColorJitter(contrast=0.5),
@@ -46,18 +46,18 @@ def color_jitters():
]
- @pytest.fixture
- def single_transforms():
+ @pytest.fixture(name="single_transforms")
+ def fixture_single_transforms():
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
- @pytest.fixture
- def img_tensor(single_transforms):
+ @pytest.fixture(name="img_tensor")
+ def fixture_img_tensor(single_transforms):
return single_transforms["original_frame"]
- @pytest.fixture
- def default_transforms():
+ @pytest.fixture(name="default_transforms")
+ def fixture_default_transforms():
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")

View File

@@ -20,8 +20,8 @@ import pytest
from lerobot.common.utils.io_utils import deserialize_json_into_object
- @pytest.fixture
- def tmp_json_file(tmp_path: Path):
+ @pytest.fixture(name="tmp_json_file")
+ def fixture_tmp_json_file(tmp_path: Path):
"""Writes `data` to a temporary JSON file and returns the file's path."""
def _write(data: Any) -> Path:

View File

@@ -16,8 +16,8 @@ import pytest
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
- @pytest.fixture
- def mock_metrics():
+ @pytest.fixture(name="mock_metrics")
+ def fixture_mock_metrics():
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
@@ -87,6 +87,7 @@ def test_metrics_tracker_getattr(mock_metrics):
_ = tracker.non_existent_metric
+ # TODO(Steven): I don't understand what's supposed to happen here
def test_metrics_tracker_setattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss = 2.0

View File

@@ -74,7 +74,7 @@ def test_non_mutate():
def test_index_error_no_data():
buffer, _ = make_new_buffer()
with pytest.raises(IndexError):
- buffer[0]
+ _ = buffer[0]
def test_index_error_with_data():
@@ -83,9 +83,9 @@ def test_index_error_with_data():
new_data = make_spoof_data_frames(1, n_frames)
buffer.add_data(new_data)
with pytest.raises(IndexError):
- buffer[n_frames]
+ _ = buffer[n_frames]
with pytest.raises(IndexError):
- buffer[-n_frames - 1]
+ _ = buffer[-n_frames - 1]
@pytest.mark.parametrize("do_reload", [False, True])
@@ -185,7 +185,7 @@ def test_delta_timestamps_outside_tolerance_inside_episode_range():
buffer.add_data(new_data)
buffer.tolerance_s = 0.04
with pytest.raises(AssertionError):
- buffer[2]
+ _ = buffer[2]
def test_delta_timestamps_outside_tolerance_outside_episode_range():
@@ -229,6 +229,7 @@ def test_compute_sampler_weights_trivial(
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
)
+ expected_weights: torch.Tensor = None
if offline_dataset_size == 0 or online_dataset_size == 0:
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
elif online_sampling_ratio == 0:
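
Assigning the indexing expression to `_`, as done above, keeps the side effect the test exercises (the raised exception) while making clear that the value itself is discarded, which silences the pointless-statement warning. A minimal sketch with a hypothetical container:

import pytest

class EmptyBuffer:
    def __getitem__(self, idx):
        raise IndexError(idx)

def test_empty_buffer_raises():
    buffer = EmptyBuffer()
    with pytest.raises(IndexError):
        _ = buffer[0]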

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ # pylint: disable=redefined-outer-name, unused-argument
import inspect
from copy import deepcopy
from pathlib import Path

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ # pylint: disable=redefined-outer-name, unused-argument
import random
import numpy as np

View File

@@ -32,16 +32,16 @@ from lerobot.common.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
TEST_ROBOT_TYPES = []
- for robot_type in available_robots:
- TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
+ for available_robot_type in available_robots:
+ TEST_ROBOT_TYPES += [(available_robot_type, True), (available_robot_type, False)]
TEST_CAMERA_TYPES = []
- for camera_type in available_cameras:
- TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
+ for available_camera_type in available_cameras:
+ TEST_CAMERA_TYPES += [(available_camera_type, True), (available_camera_type, False)]
TEST_MOTOR_TYPES = []
- for motor_type in available_motors:
- TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
+ for available_motor_type in available_motors:
+ TEST_MOTOR_TYPES += [(available_motor_type, True), (available_motor_type, False)]
# Camera indices used for connecting physical cameras
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
@@ -72,7 +72,6 @@ def require_x86_64_kernel(func):
"""
Decorator that skips the test if plateform device is not an x86_64 cpu.
"""
- from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
@@ -87,7 +86,6 @@ def require_cpu(func):
"""
Decorator that skips the test if device is not cpu.
"""
- from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
@@ -102,7 +100,6 @@ def require_cuda(func):
"""
Decorator that skips the test if cuda is not available.
"""
- from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
@@ -288,17 +285,17 @@ def mock_calibration_dir(calibration_dir):
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
}
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
- with open(calibration_dir / "main_follower.json", "w") as f:
+ with open(calibration_dir / "main_follower.json", "w", encoding="utf-8") as f:
json.dump(example_calib, f)
- with open(calibration_dir / "main_leader.json", "w") as f:
+ with open(calibration_dir / "main_leader.json", "w", encoding="utf-8") as f:
json.dump(example_calib, f)
- with open(calibration_dir / "left_follower.json", "w") as f:
+ with open(calibration_dir / "left_follower.json", "w", encoding="utf-8") as f:
json.dump(example_calib, f)
- with open(calibration_dir / "left_leader.json", "w") as f:
+ with open(calibration_dir / "left_leader.json", "w", encoding="utf-8") as f:
json.dump(example_calib, f)
- with open(calibration_dir / "right_follower.json", "w") as f:
+ with open(calibration_dir / "right_follower.json", "w", encoding="utf-8") as f:
json.dump(example_calib, f)
- with open(calibration_dir / "right_leader.json", "w") as f:
+ with open(calibration_dir / "right_leader.json", "w", encoding="utf-8") as f:
json.dump(example_calib, f)