Compare commits
15 Commits
recovered-
...
fix/lint_w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e511e7eda5 | ||
|
|
f5ed3723f0 | ||
|
|
b104be0d04 | ||
|
|
f9e4a1f5c4 | ||
|
|
0eb56cec14 | ||
|
|
e59ef036e1 | ||
|
|
9b380eaf67 | ||
|
|
1187604ba0 | ||
|
|
5c6f2d2cd0 | ||
|
|
652fedf69c | ||
|
|
85214ec303 | ||
|
|
dffa5a18db | ||
|
|
301f152a34 | ||
|
|
0ed08c0b1f | ||
|
|
254bc707e7 |
@@ -67,6 +67,7 @@ def parse_int_or_none(value) -> int | None:
|
|||||||
def check_datasets_formats(repo_ids: list) -> None:
|
def check_datasets_formats(repo_ids: list) -> None:
|
||||||
for repo_id in repo_ids:
|
for repo_id in repo_ids:
|
||||||
dataset = LeRobotDataset(repo_id)
|
dataset = LeRobotDataset(repo_id)
|
||||||
|
# TODO(Steven): Seems this API has changed
|
||||||
if dataset.video:
|
if dataset.video:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
||||||
repo_id = "lerobot/pusht"
|
repository_id = "lerobot/pusht"
|
||||||
|
|
||||||
modes = ["video", "image", "keypoints"]
|
modes = ["video", "image", "keypoints"]
|
||||||
# Uncomment if you want to try with a specific mode
|
# Uncomment if you want to try with a specific mode
|
||||||
@@ -230,13 +230,13 @@ if __name__ == "__main__":
|
|||||||
# modes = ["image"]
|
# modes = ["image"]
|
||||||
# modes = ["keypoints"]
|
# modes = ["keypoints"]
|
||||||
|
|
||||||
raw_dir = Path("data/lerobot-raw/pusht_raw")
|
data_dir = Path("data/lerobot-raw/pusht_raw")
|
||||||
for mode in modes:
|
for available_mode in modes:
|
||||||
if mode in ["image", "keypoints"]:
|
if available_mode in ["image", "keypoints"]:
|
||||||
repo_id += f"_{mode}"
|
repository_id += f"_{available_mode}"
|
||||||
|
|
||||||
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
||||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
main(data_dir, repo_id=repository_id, mode=available_mode)
|
||||||
|
|
||||||
# Uncomment if you want to load the local dataset and explore it
|
# Uncomment if you want to load the local dataset and explore it
|
||||||
# dataset = LeRobotDataset(repo_id=repo_id)
|
# dataset = LeRobotDataset(repo_id=repo_id)
|
||||||
|
|||||||
@@ -13,8 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
|
||||||
from pprint import pformat
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -98,17 +96,17 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||||
dataset = MultiLeRobotDataset(
|
# dataset = MultiLeRobotDataset(
|
||||||
cfg.dataset.repo_id,
|
# cfg.dataset.repo_id,
|
||||||
# TODO(aliberts): add proper support for multi dataset
|
# # TODO(aliberts): add proper support for multi dataset
|
||||||
# delta_timestamps=delta_timestamps,
|
# # delta_timestamps=delta_timestamps,
|
||||||
image_transforms=image_transforms,
|
# image_transforms=image_transforms,
|
||||||
video_backend=cfg.dataset.video_backend,
|
# video_backend=cfg.dataset.video_backend,
|
||||||
)
|
# )
|
||||||
logging.info(
|
# logging.info(
|
||||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
# "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
# f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||||
)
|
# )
|
||||||
|
|
||||||
if cfg.dataset.use_imagenet_stats:
|
if cfg.dataset.use_imagenet_stats:
|
||||||
for key in dataset.meta.camera_keys:
|
for key in dataset.meta.camera_keys:
|
||||||
|
|||||||
@@ -81,21 +81,21 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
|||||||
print(f"Error writing image {fpath}: {e}")
|
print(f"Error writing image {fpath}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def worker_thread_loop(queue: queue.Queue):
|
def worker_thread_loop(task_queue: queue.Queue):
|
||||||
while True:
|
while True:
|
||||||
item = queue.get()
|
item = task_queue.get()
|
||||||
if item is None:
|
if item is None:
|
||||||
queue.task_done()
|
task_queue.task_done()
|
||||||
break
|
break
|
||||||
image_array, fpath = item
|
image_array, fpath = item
|
||||||
write_image(image_array, fpath)
|
write_image(image_array, fpath)
|
||||||
queue.task_done()
|
task_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
def worker_process(queue: queue.Queue, num_threads: int):
|
def worker_process(task_queue: queue.Queue, num_threads: int):
|
||||||
threads = []
|
threads = []
|
||||||
for _ in range(num_threads):
|
for _ in range(num_threads):
|
||||||
t = threading.Thread(target=worker_thread_loop, args=(queue,))
|
t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
t.start()
|
t.start()
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ class LeRobotDatasetMetadata:
|
|||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||||
|
self.stats = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if force_cache_sync:
|
if force_cache_sync:
|
||||||
@@ -102,10 +103,10 @@ class LeRobotDatasetMetadata:
|
|||||||
|
|
||||||
def load_metadata(self):
|
def load_metadata(self):
|
||||||
self.info = load_info(self.root)
|
self.info = load_info(self.root)
|
||||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
check_version_compatibility(self.repo_id, self.version, CODEBASE_VERSION)
|
||||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||||
self.episodes = load_episodes(self.root)
|
self.episodes = load_episodes(self.root)
|
||||||
if self._version < packaging.version.parse("v2.1"):
|
if self.version < packaging.version.parse("v2.1"):
|
||||||
self.stats = load_stats(self.root)
|
self.stats = load_stats(self.root)
|
||||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||||
else:
|
else:
|
||||||
@@ -127,7 +128,7 @@ class LeRobotDatasetMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _version(self) -> packaging.version.Version:
|
def version(self) -> packaging.version.Version:
|
||||||
"""Codebase version used to create this dataset."""
|
"""Codebase version used to create this dataset."""
|
||||||
return packaging.version.parse(self.info["codebase_version"])
|
return packaging.version.parse(self.info["codebase_version"])
|
||||||
|
|
||||||
@@ -321,8 +322,9 @@ class LeRobotDatasetMetadata:
|
|||||||
robot_type = robot.robot_type
|
robot_type = robot.robot_type
|
||||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
"Some cameras in your %s robot don't have an fps matching the fps of your dataset."
|
||||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
"In this case, frames from lower fps cameras will be repeated to fill in the blanks.",
|
||||||
|
robot.robot_type,
|
||||||
)
|
)
|
||||||
elif features is None:
|
elif features is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -486,7 +488,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.meta = LeRobotDatasetMetadata(
|
self.meta = LeRobotDatasetMetadata(
|
||||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||||
)
|
)
|
||||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
if self.episodes is not None and self.meta.version >= packaging.version.parse("v2.1"):
|
||||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||||
self.stats = aggregate_stats(episodes_stats)
|
self.stats = aggregate_stats(episodes_stats)
|
||||||
|
|
||||||
@@ -518,7 +520,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self,
|
self,
|
||||||
branch: str | None = None,
|
branch: str | None = None,
|
||||||
tags: list | None = None,
|
tags: list | None = None,
|
||||||
license: str | None = "apache-2.0",
|
dataset_license: str | None = "apache-2.0",
|
||||||
tag_version: bool = True,
|
tag_version: bool = True,
|
||||||
push_videos: bool = True,
|
push_videos: bool = True,
|
||||||
private: bool = False,
|
private: bool = False,
|
||||||
@@ -561,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||||
card = create_lerobot_dataset_card(
|
card = create_lerobot_dataset_card(
|
||||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
tags=tags, dataset_info=self.meta.info, license=dataset_license, **card_kwargs
|
||||||
)
|
)
|
||||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||||
|
|
||||||
@@ -842,6 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||||
None.
|
None.
|
||||||
"""
|
"""
|
||||||
|
episode_buffer = None
|
||||||
if not episode_data:
|
if not episode_data:
|
||||||
episode_buffer = self.episode_buffer
|
episode_buffer = self.episode_buffer
|
||||||
|
|
||||||
@@ -1086,8 +1089,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||||
extra_keys = set(ds.features).difference(intersection_features)
|
extra_keys = set(ds.features).difference(intersection_features)
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
"keys %s of %s were disabled as they are not contained in all the other datasets.",
|
||||||
"other datasets."
|
extra_keys,
|
||||||
|
repo_id,
|
||||||
)
|
)
|
||||||
self.disabled_features.update(extra_keys)
|
self.disabled_features.update(extra_keys)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
|
|||||||
# rechunk recompress
|
# rechunk recompress
|
||||||
group.move(name, tmp_key)
|
group.move(name, tmp_key)
|
||||||
old_arr = group[tmp_key]
|
old_arr = group[tmp_key]
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||||
source=old_arr,
|
source=old_arr,
|
||||||
dest=group,
|
dest=group,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -192,7 +192,7 @@ class ReplayBuffer:
|
|||||||
else:
|
else:
|
||||||
root = zarr.group(store=store)
|
root = zarr.group(store=store)
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||||
)
|
)
|
||||||
data_group = root.create_group("data", overwrite=True)
|
data_group = root.create_group("data", overwrite=True)
|
||||||
@@ -205,7 +205,7 @@ class ReplayBuffer:
|
|||||||
if cks == value.chunks and cpr == value.compressor:
|
if cks == value.chunks and cpr == value.compressor:
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
this_path = "/data/" + key
|
this_path = "/data/" + key
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=src_store,
|
source=src_store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path=this_path,
|
source_path=this_path,
|
||||||
@@ -214,7 +214,7 @@ class ReplayBuffer:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# copy with recompression
|
# copy with recompression
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||||
source=value,
|
source=value,
|
||||||
dest=data_group,
|
dest=data_group,
|
||||||
name=key,
|
name=key,
|
||||||
@@ -275,7 +275,7 @@ class ReplayBuffer:
|
|||||||
compressors = {}
|
compressors = {}
|
||||||
if self.backend == "zarr":
|
if self.backend == "zarr":
|
||||||
# recompression free copy
|
# recompression free copy
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=self.root.store,
|
source=self.root.store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path="/meta",
|
source_path="/meta",
|
||||||
@@ -297,7 +297,7 @@ class ReplayBuffer:
|
|||||||
if cks == value.chunks and cpr == value.compressor:
|
if cks == value.chunks and cpr == value.compressor:
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
this_path = "/data/" + key
|
this_path = "/data/" + key
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=self.root.store,
|
source=self.root.store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path=this_path,
|
source_path=this_path,
|
||||||
|
|||||||
@@ -162,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
|
|||||||
)
|
)
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||||
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
||||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
def download_all_raw_datasets(data_dir: Path | None = None):
|
def download_all_raw_datasets(data_dir: Path | None = None):
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
|
|||||||
assert data[f"/observations/images/{camera}"].ndim == 2
|
assert data[f"/observations/images/{camera}"].ndim == 2
|
||||||
else:
|
else:
|
||||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
_, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||||
|
|
||||||
|
|
||||||
@@ -103,6 +103,7 @@ def load_from_raw(
|
|||||||
|
|
||||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||||
action = torch.from_numpy(ep["/action"][:])
|
action = torch.from_numpy(ep["/action"][:])
|
||||||
|
velocity = None
|
||||||
if "/observations/qvel" in ep:
|
if "/observations/qvel" in ep:
|
||||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||||
if "/observations/effort" in ep:
|
if "/observations/effort" in ep:
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ def from_raw_to_lerobot_format(
|
|||||||
if fps is None:
|
if fps is None:
|
||||||
fps = 30
|
fps = 30
|
||||||
|
|
||||||
|
# TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
|
||||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
def load_from_raw(
|
||||||
|
raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
|
||||||
|
):
|
||||||
# Load data stream that will be used as reference for the timestamps synchronization
|
# Load data stream that will be used as reference for the timestamps synchronization
|
||||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||||
if len(reference_files) == 0:
|
if len(reference_files) == 0:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
|||||||
|
|
||||||
num_images = len(imgs_array)
|
num_images = len(imgs_array)
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
_ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||||
|
|
||||||
|
|
||||||
def get_default_encoding() -> dict:
|
def get_default_encoding() -> dict:
|
||||||
@@ -92,24 +92,23 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
|||||||
episode_data_index = {"from": [], "to": []}
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
current_episode = None
|
current_episode = None
|
||||||
"""
|
# The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
# For instance, the following is a valid episode_index:
|
||||||
For instance, the following is a valid episode_index:
|
# [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
#
|
||||||
|
# Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
# ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
# {
|
||||||
{
|
# "from": [0, 3, 7],
|
||||||
"from": [0, 3, 7],
|
# "to": [3, 7, 12]
|
||||||
"to": [3, 7, 12]
|
# }
|
||||||
}
|
|
||||||
"""
|
|
||||||
if len(hf_dataset) == 0:
|
if len(hf_dataset) == 0:
|
||||||
episode_data_index = {
|
episode_data_index = {
|
||||||
"from": torch.tensor([]),
|
"from": torch.tensor([]),
|
||||||
"to": torch.tensor([]),
|
"to": torch.tensor([]),
|
||||||
}
|
}
|
||||||
return episode_data_index
|
return episode_data_index
|
||||||
|
idx = None
|
||||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||||
if episode_idx != current_episode:
|
if episode_idx != current_episode:
|
||||||
# We encountered a new episode, so we append its starting location to the "from" list
|
# We encountered a new episode, so we append its starting location to the "from" list
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from torchvision.transforms.v2 import Transform
|
|||||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing transform() implementation
|
||||||
class RandomSubsetApply(Transform):
|
class RandomSubsetApply(Transform):
|
||||||
"""Apply a random subset of N transformations from a list of transformations.
|
"""Apply a random subset of N transformations from a list of transformations.
|
||||||
|
|
||||||
@@ -218,6 +219,7 @@ def make_transform_from_config(cfg: ImageTransformConfig):
|
|||||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing transform() implementation
|
||||||
class ImageTransforms(Transform):
|
class ImageTransforms(Transform):
|
||||||
"""A class to compose image transforms based on configuration."""
|
"""A class to compose image transforms based on configuration."""
|
||||||
|
|
||||||
|
|||||||
@@ -135,21 +135,21 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
|||||||
|
|
||||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||||
# Embed image bytes into the table before saving to parquet
|
# Embed image bytes into the table before saving to parquet
|
||||||
format = dataset.format
|
ds_format = dataset.format
|
||||||
dataset = dataset.with_format("arrow")
|
dataset = dataset.with_format("arrow")
|
||||||
dataset = dataset.map(embed_table_storage, batched=False)
|
dataset = dataset.map(embed_table_storage, batched=False)
|
||||||
dataset = dataset.with_format(**format)
|
dataset = dataset.with_format(**ds_format)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def load_json(fpath: Path) -> Any:
|
def load_json(fpath: Path) -> Any:
|
||||||
with open(fpath) as f:
|
with open(fpath, encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def write_json(data: dict, fpath: Path) -> None:
|
def write_json(data: dict, fpath: Path) -> None:
|
||||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with open(fpath, "w") as f:
|
with open(fpath, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
@@ -300,7 +300,7 @@ def check_version_compatibility(
|
|||||||
if v_check.major < v_current.major and enforce_breaking_major:
|
if v_check.major < v_current.major and enforce_breaking_major:
|
||||||
raise BackwardCompatibilityError(repo_id, v_check)
|
raise BackwardCompatibilityError(repo_id, v_check)
|
||||||
elif v_check.minor < v_current.minor:
|
elif v_check.minor < v_current.minor:
|
||||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
logging.warning("%s", V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||||
|
|
||||||
|
|
||||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||||
@@ -348,7 +348,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
|||||||
if compatibles:
|
if compatibles:
|
||||||
return_version = max(compatibles)
|
return_version = max(compatibles)
|
||||||
if return_version < target_version:
|
if return_version < target_version:
|
||||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
logging.warning(
|
||||||
|
"Revision %s for %s not found, using version v%s", version, repo_id, return_version
|
||||||
|
)
|
||||||
return f"v{return_version}"
|
return f"v{return_version}"
|
||||||
|
|
||||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||||
@@ -403,7 +405,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
for key, ft in features.items():
|
for key, ft in features.items():
|
||||||
shape = ft["shape"]
|
shape = ft["shape"]
|
||||||
if ft["dtype"] in ["image", "video"]:
|
if ft["dtype"] in ["image", "video"]:
|
||||||
type = FeatureType.VISUAL
|
feature_type = FeatureType.VISUAL
|
||||||
if len(shape) != 3:
|
if len(shape) != 3:
|
||||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||||
|
|
||||||
@@ -412,16 +414,16 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
shape = (shape[2], shape[0], shape[1])
|
shape = (shape[2], shape[0], shape[1])
|
||||||
elif key == "observation.environment_state":
|
elif key == "observation.environment_state":
|
||||||
type = FeatureType.ENV
|
feature_type = FeatureType.ENV
|
||||||
elif key.startswith("observation"):
|
elif key.startswith("observation"):
|
||||||
type = FeatureType.STATE
|
feature_type = FeatureType.STATE
|
||||||
elif key == "action":
|
elif key == "action":
|
||||||
type = FeatureType.ACTION
|
feature_type = FeatureType.ACTION
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
policy_features[key] = PolicyFeature(
|
policy_features[key] = PolicyFeature(
|
||||||
type=type,
|
type=feature_type,
|
||||||
shape=shape,
|
shape=shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -871,11 +871,11 @@ def batch_convert():
|
|||||||
try:
|
try:
|
||||||
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
||||||
status = f"{repo_id}: success."
|
status = f"{repo_id}: success."
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
except Exception:
|
except Exception:
|
||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -190,11 +190,11 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
|||||||
|
|
||||||
json_path = v2_dir / STATS_PATH
|
json_path = v2_dir / STATS_PATH
|
||||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with open(json_path, "w") as f:
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(serialized_stats, f, indent=4)
|
json.dump(serialized_stats, f, indent=4)
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
with open(json_path) as f:
|
with open(json_path, encoding="utf-8") as f:
|
||||||
stats_json = json.load(f)
|
stats_json = json.load(f)
|
||||||
|
|
||||||
stats_json = flatten_dict(stats_json)
|
stats_json = flatten_dict(stats_json)
|
||||||
@@ -213,7 +213,7 @@ def get_features_from_hf_dataset(
|
|||||||
dtype = ft.dtype
|
dtype = ft.dtype
|
||||||
shape = (1,)
|
shape = (1,)
|
||||||
names = None
|
names = None
|
||||||
if isinstance(ft, datasets.Sequence):
|
elif isinstance(ft, datasets.Sequence):
|
||||||
assert isinstance(ft.feature, datasets.Value)
|
assert isinstance(ft.feature, datasets.Value)
|
||||||
dtype = ft.feature.dtype
|
dtype = ft.feature.dtype
|
||||||
shape = (ft.length,)
|
shape = (ft.length,)
|
||||||
@@ -232,6 +232,8 @@ def get_features_from_hf_dataset(
|
|||||||
dtype = "video"
|
dtype = "video"
|
||||||
shape = None # Add shape later
|
shape = None # Add shape later
|
||||||
names = ["height", "width", "channels"]
|
names = ["height", "width", "channels"]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Feature type {ft._type} not supported.")
|
||||||
|
|
||||||
features[key] = {
|
features[key] = {
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
@@ -358,9 +360,9 @@ def move_videos(
|
|||||||
if len(video_dirs) == 1:
|
if len(video_dirs) == 1:
|
||||||
video_path = video_dirs[0] / video_file
|
video_path = video_dirs[0] / video_file
|
||||||
else:
|
else:
|
||||||
for dir in video_dirs:
|
for v_dir in video_dirs:
|
||||||
if (dir / video_file).is_file():
|
if (v_dir / video_file).is_file():
|
||||||
video_path = dir / video_file
|
video_path = v_dir / video_file
|
||||||
break
|
break
|
||||||
|
|
||||||
video_path.rename(work_dir / target_path)
|
video_path.rename(work_dir / target_path)
|
||||||
@@ -652,6 +654,7 @@ def main():
|
|||||||
if not args.local_dir:
|
if not args.local_dir:
|
||||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||||
|
|
||||||
|
robot_config = None
|
||||||
if args.robot is not None:
|
if args.robot is not None:
|
||||||
robot_config = make_robot_config(args.robot)
|
robot_config = make_robot_config(args.robot)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ def fix_dataset(repo_id: str) -> str:
|
|||||||
return f"{repo_id}: skipped (no diff)"
|
return f"{repo_id}: skipped (no diff)"
|
||||||
|
|
||||||
if diff_meta_parquet:
|
if diff_meta_parquet:
|
||||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
|
||||||
assert diff_meta_parquet == {"language_instruction"}
|
assert diff_meta_parquet == {"language_instruction"}
|
||||||
lerobot_metadata.features.pop("language_instruction")
|
lerobot_metadata.features.pop("language_instruction")
|
||||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||||
@@ -79,7 +79,7 @@ def batch_fix():
|
|||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
|
|
||||||
logging.info(status)
|
logging.info(status)
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def batch_convert():
|
|||||||
except Exception:
|
except Exception:
|
||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
|
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,9 @@ V21 = "v2.1"
|
|||||||
|
|
||||||
|
|
||||||
class SuppressWarnings:
|
class SuppressWarnings:
|
||||||
|
def __init__(self):
|
||||||
|
self.previous_level = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||||
logging.getLogger().setLevel(logging.ERROR)
|
logging.getLogger().setLevel(logging.ERROR)
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ def decode_video_frames_torchvision(
|
|||||||
for frame in reader:
|
for frame in reader:
|
||||||
current_ts = frame["pts"]
|
current_ts = frame["pts"]
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
logging.info("frame loaded at timestamp=%.4f", current_ts)
|
||||||
loaded_frames.append(frame["data"])
|
loaded_frames.append(frame["data"])
|
||||||
loaded_ts.append(current_ts)
|
loaded_ts.append(current_ts)
|
||||||
if current_ts >= last_ts:
|
if current_ts >= last_ts:
|
||||||
@@ -118,7 +118,7 @@ def decode_video_frames_torchvision(
|
|||||||
closest_ts = loaded_ts[argmin_]
|
closest_ts = loaded_ts[argmin_]
|
||||||
|
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logging.info(f"{closest_ts=}")
|
logging.info("closest_ts=%s", closest_ts)
|
||||||
|
|
||||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||||
closest_frames = closest_frames.type(torch.float32) / 255
|
closest_frames = closest_frames.type(torch.float32) / 255
|
||||||
@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
|
|||||||
"json",
|
"json",
|
||||||
str(video_path),
|
str(video_path),
|
||||||
]
|
]
|
||||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
result = subprocess.run(
|
||||||
|
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||||
|
)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||||
|
|
||||||
@@ -263,7 +265,9 @@ def get_video_info(video_path: Path | str) -> dict:
|
|||||||
"json",
|
"json",
|
||||||
str(video_path),
|
str(video_path),
|
||||||
]
|
]
|
||||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
result = subprocess.run(
|
||||||
|
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||||
|
)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return self.get_choice_name(self.__class__)
|
return self.get_choice_name(self.__class__)
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
return "adam"
|
return "adam"
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def build(self) -> torch.optim.Optimizer:
|
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class ACTConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -222,6 +222,8 @@ class ACTTemporalEnsembler:
|
|||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||||
|
self.ensembled_actions = None
|
||||||
|
self.ensembled_actions_count = None
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
|
|||||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing forward() implementation
|
||||||
class DiffusionModel(nn.Module):
|
class DiffusionModel(nn.Module):
|
||||||
def __init__(self, config: DiffusionConfig):
|
def __init__(self, config: DiffusionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -203,6 +204,7 @@ class DiffusionModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.num_inference_steps is None:
|
if config.num_inference_steps is None:
|
||||||
|
# TODO(Steven): Consider type check?
|
||||||
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||||
else:
|
else:
|
||||||
self.num_inference_steps = config.num_inference_steps
|
self.num_inference_steps = config.num_inference_steps
|
||||||
@@ -333,7 +335,7 @@ class DiffusionModel(nn.Module):
|
|||||||
# Sample a random noising timestep for each item in the batch.
|
# Sample a random noising timestep for each item in the batch.
|
||||||
timesteps = torch.randint(
|
timesteps = torch.randint(
|
||||||
low=0,
|
low=0,
|
||||||
high=self.noise_scheduler.config.num_train_timesteps,
|
high=self.noise_scheduler.config.num_train_timesteps, # TODO(Steven): Consider type check?
|
||||||
size=(trajectory.shape[0],),
|
size=(trajectory.shape[0],),
|
||||||
device=trajectory.device,
|
device=trajectory.device,
|
||||||
).long()
|
).long()
|
||||||
|
|||||||
@@ -69,12 +69,12 @@ def create_stats_buffers(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
min_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
max_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
buffer = nn.ParameterDict(
|
buffer = nn.ParameterDict(
|
||||||
{
|
{
|
||||||
"min": nn.Parameter(min, requires_grad=False),
|
"min": nn.Parameter(min_norm, requires_grad=False),
|
||||||
"max": nn.Parameter(max, requires_grad=False),
|
"max": nn.Parameter(max_norm, requires_grad=False),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -170,12 +170,12 @@ class Normalize(nn.Module):
|
|||||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = buffer["min"]
|
min_norm = buffer["min"]
|
||||||
max = buffer["max"]
|
max_norm = buffer["max"]
|
||||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
batch[key] = (batch[key] - min_norm) / (max_norm - min_norm + 1e-8)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
batch[key] = batch[key] * 2 - 1
|
batch[key] = batch[key] * 2 - 1
|
||||||
else:
|
else:
|
||||||
@@ -243,12 +243,12 @@ class Unnormalize(nn.Module):
|
|||||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
batch[key] = batch[key] * std + mean
|
batch[key] = batch[key] * std + mean
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = buffer["min"]
|
min_norm = buffer["min"]
|
||||||
max = buffer["max"]
|
max_norm = buffer["max"]
|
||||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||||
batch[key] = (batch[key] + 1) / 2
|
batch[key] = (batch[key] + 1) / 2
|
||||||
batch[key] = batch[key] * (max - min) + min
|
batch[key] = batch[key] * (max_norm - min_norm) + min_norm
|
||||||
else:
|
else:
|
||||||
raise ValueError(norm_mode)
|
raise ValueError(norm_mode)
|
||||||
return batch
|
return batch
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class PI0Config(PreTrainedConfig):
|
|||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
# TODO(Steven): Validate device and amp? in all policy configs?
|
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def main():
|
|||||||
with open(save_dir / "noise.pkl", "rb") as f:
|
with open(save_dir / "noise.pkl", "rb") as f:
|
||||||
noise = pickle.load(f)
|
noise = pickle.load(f)
|
||||||
|
|
||||||
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
with open(ckpt_jax_dir / "assets/norm_stats.json", encoding="utf-8") as f:
|
||||||
norm_stats = json.load(f)
|
norm_stats = json.load(f)
|
||||||
|
|
||||||
# Override stats
|
# Override stats
|
||||||
|
|||||||
@@ -318,7 +318,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
|||||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||||
|
|
||||||
|
|
||||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, _tokenizer_id: str, output_path: str):
|
||||||
# Break down orbax ckpts - they are in OCDBT
|
# Break down orbax ckpts - they are in OCDBT
|
||||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||||
# process projection params
|
# process projection params
|
||||||
@@ -432,6 +432,6 @@ if __name__ == "__main__":
|
|||||||
convert_pi0_checkpoint(
|
convert_pi0_checkpoint(
|
||||||
checkpoint_dir=args.checkpoint_dir,
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
precision=args.precision,
|
precision=args.precision,
|
||||||
tokenizer_id=args.tokenizer_hub_id,
|
_tokenizer_id=args.tokenizer_hub_id,
|
||||||
output_path=args.output_path,
|
output_path=args.output_path,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import torch
|
|||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
|
# TODO(Steven): Consider settings this a dependency constraint
|
||||||
if Version(torch.__version__) > Version("2.5.0"):
|
if Version(torch.__version__) > Version("2.5.0"):
|
||||||
# Ffex attention is only available from torch 2.5 onwards
|
# Ffex attention is only available from torch 2.5 onwards
|
||||||
from torch.nn.attention.flex_attention import (
|
from torch.nn.attention.flex_attention import (
|
||||||
@@ -121,7 +122,7 @@ def flex_attention_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||||
attn_output, attention_weights = flex_attention(
|
attn_output, _attention_weights = flex_attention(
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class TDMPCConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if self.n_gaussian_samples <= 0:
|
if self.n_gaussian_samples <= 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||||
|
|||||||
@@ -88,6 +88,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
for param in self.model_target.parameters():
|
for param in self.model_target.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
self._queues = None
|
||||||
|
self._prev_mean: torch.Tensor | None = None
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
@@ -108,7 +111,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||||
# CEM for the next step.
|
# CEM for the next step.
|
||||||
self._prev_mean: torch.Tensor | None = None
|
self._prev_mean = None
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
@@ -514,6 +517,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): forward implementation missing
|
||||||
class TDMPCTOLD(nn.Module):
|
class TDMPCTOLD(nn.Module):
|
||||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||||
|
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class VQBeTConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -70,6 +70,8 @@ class VQBeTPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
self.vqbet = VQBeTModel(config)
|
self.vqbet = VQBeTModel(config)
|
||||||
|
|
||||||
|
self._queues = None
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
@@ -535,7 +537,7 @@ class VQBeTHead(nn.Module):
|
|||||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||||
)
|
)
|
||||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||||
NT, G, choices = cbet_probs.shape
|
NT, _G, choices = cbet_probs.shape
|
||||||
sampled_centers = einops.rearrange(
|
sampled_centers = einops.rearrange(
|
||||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||||
"(NT G) 1 -> NT G",
|
"(NT G) 1 -> NT G",
|
||||||
@@ -578,7 +580,7 @@ class VQBeTHead(nn.Module):
|
|||||||
"decoded_action": decoded_action,
|
"decoded_action": decoded_action,
|
||||||
}
|
}
|
||||||
|
|
||||||
def loss_fn(self, pred, target, **kwargs):
|
def loss_fn(self, pred, target, **_kwargs):
|
||||||
"""
|
"""
|
||||||
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
||||||
|
|
||||||
@@ -605,7 +607,7 @@ class VQBeTHead(nn.Module):
|
|||||||
# Figure out the loss for the actions.
|
# Figure out the loss for the actions.
|
||||||
# First, we need to find the closest cluster center for each ground truth action.
|
# First, we need to find the closest cluster center for each ground truth action.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
_state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||||
|
|
||||||
# Now we can compute the loss.
|
# Now we can compute the loss.
|
||||||
|
|
||||||
@@ -762,6 +764,7 @@ def _replace_submodules(
|
|||||||
return root_module
|
return root_module
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing implementation of forward, is it maybe vqvae_forward?
|
||||||
class VqVae(nn.Module):
|
class VqVae(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -876,13 +879,13 @@ class FocalLoss(nn.Module):
|
|||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.size_average = size_average
|
self.size_average = size_average
|
||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, forward_input, target):
|
||||||
if len(input.shape) == 3:
|
if len(forward_input.shape) == 3:
|
||||||
N, T, _ = input.shape
|
N, T, _ = forward_input.shape
|
||||||
logpt = F.log_softmax(input, dim=-1)
|
logpt = F.log_softmax(forward_input, dim=-1)
|
||||||
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
||||||
elif len(input.shape) == 2:
|
elif len(forward_input.shape) == 2:
|
||||||
logpt = F.log_softmax(input, dim=-1)
|
logpt = F.log_softmax(forward_input, dim=-1)
|
||||||
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
||||||
pt = logpt.exp()
|
pt = logpt.exp()
|
||||||
|
|
||||||
|
|||||||
@@ -34,63 +34,58 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
|||||||
|
|
||||||
# ruff: noqa: N806
|
# ruff: noqa: N806
|
||||||
|
|
||||||
"""
|
# This file is part of a VQ-BeT that utilizes code from the following repositories:
|
||||||
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
#
|
||||||
|
# - Vector Quantize PyTorch code is licensed under the MIT License:
|
||||||
|
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||||
|
#
|
||||||
|
# - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||||
|
# Original source: https://github.com/karpathy/nanoGPT
|
||||||
|
#
|
||||||
|
# We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
||||||
|
|
||||||
- Vector Quantize PyTorch code is licensed under the MIT License:
|
# This is a part for nanoGPT that utilizes code from the following repository:
|
||||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
#
|
||||||
|
# - Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||||
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
# Original source: https://github.com/karpathy/nanoGPT
|
||||||
Original source: https://github.com/karpathy/nanoGPT
|
#
|
||||||
|
# - The nanoGPT code is licensed under the MIT License:
|
||||||
We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
#
|
||||||
"""
|
# MIT License
|
||||||
|
#
|
||||||
"""
|
# Copyright (c) 2022 Andrej Karpathy
|
||||||
This is a part for nanoGPT that utilizes code from the following repository:
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
- Andrej Karpathy's nanoGPT implementation in PyTorch.
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
Original source: https://github.com/karpathy/nanoGPT
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
- The nanoGPT code is licensed under the MIT License:
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
MIT License
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
Copyright (c) 2022 Andrej Karpathy
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
in the Software without restriction, including without limitation the rights
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
furnished to do so, subject to the following conditions:
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
The above copyright notice and this permission notice shall be included in all
|
#
|
||||||
copies or substantial portions of the Software.
|
# - We've made some changes to the original code to adapt it to our needs.
|
||||||
|
#
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# Changed variable names:
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# - n_head -> gpt_n_head
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
# - n_embd -> gpt_hidden_dim
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# - block_size -> gpt_block_size
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
# - n_layer -> gpt_n_layer
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
#
|
||||||
SOFTWARE.
|
#
|
||||||
|
# class GPT(nn.Module):
|
||||||
- We've made some changes to the original code to adapt it to our needs.
|
# - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
||||||
|
# - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
||||||
Changed variable names:
|
# - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
||||||
- n_head -> gpt_n_head
|
|
||||||
- n_embd -> gpt_hidden_dim
|
|
||||||
- block_size -> gpt_block_size
|
|
||||||
- n_layer -> gpt_n_layer
|
|
||||||
|
|
||||||
|
|
||||||
class GPT(nn.Module):
|
|
||||||
- removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
|
||||||
- changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
|
||||||
- in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class CausalSelfAttention(nn.Module):
|
class CausalSelfAttention(nn.Module):
|
||||||
@@ -200,9 +195,9 @@ class GPT(nn.Module):
|
|||||||
n_params = sum(p.numel() for p in self.parameters())
|
n_params = sum(p.numel() for p in self.parameters())
|
||||||
print("number of parameters: {:.2f}M".format(n_params / 1e6))
|
print("number of parameters: {:.2f}M".format(n_params / 1e6))
|
||||||
|
|
||||||
def forward(self, input, targets=None):
|
def forward(self, forward_input):
|
||||||
device = input.device
|
device = forward_input.device
|
||||||
b, t, d = input.size()
|
_, t, _ = forward_input.size()
|
||||||
assert t <= self.config.gpt_block_size, (
|
assert t <= self.config.gpt_block_size, (
|
||||||
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||||
)
|
)
|
||||||
@@ -211,7 +206,7 @@ class GPT(nn.Module):
|
|||||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||||
|
|
||||||
# forward the GPT model itself
|
# forward the GPT model itself
|
||||||
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
tok_emb = self.transformer.wte(forward_input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||||
x = self.transformer.drop(tok_emb + pos_emb)
|
x = self.transformer.drop(tok_emb + pos_emb)
|
||||||
for block in self.transformer.h:
|
for block in self.transformer.h:
|
||||||
@@ -285,51 +280,48 @@ class GPT(nn.Module):
|
|||||||
return decay, no_decay
|
return decay, no_decay
|
||||||
|
|
||||||
|
|
||||||
"""
|
# This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
||||||
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
#
|
||||||
|
# - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
||||||
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
#
|
||||||
|
# - The vector-quantize-pytorch code is licensed under the MIT License:
|
||||||
- The vector-quantize-pytorch code is licensed under the MIT License:
|
#
|
||||||
|
# MIT License
|
||||||
MIT License
|
#
|
||||||
|
# Copyright (c) 2020 Phil Wang
|
||||||
Copyright (c) 2020 Phil Wang
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
# in the Software without restriction, including without limitation the rights
|
||||||
in the Software without restriction, including without limitation the rights
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
# furnished to do so, subject to the following conditions:
|
||||||
furnished to do so, subject to the following conditions:
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
The above copyright notice and this permission notice shall be included in all
|
# copies or substantial portions of the Software.
|
||||||
copies or substantial portions of the Software.
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# SOFTWARE.
|
||||||
SOFTWARE.
|
#
|
||||||
|
# - We've made some changes to the original code to adapt it to our needs.
|
||||||
- We've made some changes to the original code to adapt it to our needs.
|
#
|
||||||
|
# class ResidualVQ(nn.Module):
|
||||||
class ResidualVQ(nn.Module):
|
# - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
||||||
- added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
# This enables the user to save an indicator whether the codebook is frozen or not.
|
||||||
This enables the user to save an indicator whether the codebook is frozen or not.
|
# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||||
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
# This is to make the function name more descriptive.
|
||||||
This is to make the function name more descriptive.
|
#
|
||||||
|
# class VectorQuantize(nn.Module):
|
||||||
class VectorQuantize(nn.Module):
|
# - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
||||||
- removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
# These parameters are not used in the code.
|
||||||
These parameters are not used in the code.
|
# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||||
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
# This is to make the function name more descriptive.
|
||||||
This is to make the function name more descriptive.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualVQ(nn.Module):
|
class ResidualVQ(nn.Module):
|
||||||
@@ -479,6 +471,9 @@ class ResidualVQ(nn.Module):
|
|||||||
|
|
||||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||||
|
|
||||||
|
null_indices = None
|
||||||
|
null_loss = None
|
||||||
|
|
||||||
# sample a layer index at which to dropout further residual quantization
|
# sample a layer index at which to dropout further residual quantization
|
||||||
# also prepare null indices and loss
|
# also prepare null indices and loss
|
||||||
|
|
||||||
@@ -933,7 +928,7 @@ class VectorQuantize(nn.Module):
|
|||||||
return quantize, embed_ind, loss
|
return quantize, embed_ind, loss
|
||||||
|
|
||||||
|
|
||||||
def noop(*args, **kwargs):
|
def noop(*_args, **_kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -77,9 +77,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
|
|||||||
path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
|
path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
img.save(str(path), quality=100)
|
img.save(str(path), quality=100)
|
||||||
logging.info(f"Saved image: {path}")
|
logging.info("Saved image: %s", path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
|
logging.error("Failed to save image for camera %s frame %s: %s", serial_number, frame_index, e)
|
||||||
|
|
||||||
|
|
||||||
def save_images_from_cameras(
|
def save_images_from_cameras(
|
||||||
@@ -447,7 +447,7 @@ class IntelRealSenseCamera:
|
|||||||
num_tries += 1
|
num_tries += 1
|
||||||
time.sleep(1 / self.fps)
|
time.sleep(1 / self.fps)
|
||||||
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
||||||
raise Exception(
|
raise TimeoutError(
|
||||||
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
|||||||
MAX_OPENCV_INDEX = 60
|
MAX_OPENCV_INDEX = 60
|
||||||
|
|
||||||
|
|
||||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
def find_cameras(max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||||
cameras = []
|
cameras = []
|
||||||
if platform.system() == "Linux":
|
if platform.system() == "Linux":
|
||||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||||
|
|||||||
@@ -169,7 +169,8 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
|
|||||||
return steps
|
return steps
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bytes(value, bytes, mock=False):
|
# TODO(Steven): Similar function in feetch.py, should be moved to a common place.
|
||||||
|
def convert_to_bytes(value, byte, mock=False):
|
||||||
if mock:
|
if mock:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@@ -177,16 +178,16 @@ def convert_to_bytes(value, bytes, mock=False):
|
|||||||
|
|
||||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||||
# already handles it for us.
|
# already handles it for us.
|
||||||
if bytes == 1:
|
if byte == 1:
|
||||||
data = [
|
data = [
|
||||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||||
]
|
]
|
||||||
elif bytes == 2:
|
elif byte == 2:
|
||||||
data = [
|
data = [
|
||||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||||
]
|
]
|
||||||
elif bytes == 4:
|
elif byte == 4:
|
||||||
data = [
|
data = [
|
||||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||||
@@ -196,7 +197,7 @@ def convert_to_bytes(value, bytes, mock=False):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||||
f"{bytes} is provided instead."
|
f"{byte} is provided instead."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -228,9 +229,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
|
|||||||
all_addr = []
|
all_addr = []
|
||||||
all_bytes = []
|
all_bytes = []
|
||||||
for model in motor_models:
|
for model in motor_models:
|
||||||
addr, bytes = model_ctrl_table[model][data_name]
|
addr, byte = model_ctrl_table[model][data_name]
|
||||||
all_addr.append(addr)
|
all_addr.append(addr)
|
||||||
all_bytes.append(bytes)
|
all_bytes.append(byte)
|
||||||
|
|
||||||
if len(set(all_addr)) != 1:
|
if len(set(all_addr)) != 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -576,6 +577,8 @@ class DynamixelMotorsBus:
|
|||||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||||
low_factor = (start_pos - values[i]) / resolution
|
low_factor = (start_pos - values[i]) / resolution
|
||||||
upp_factor = (end_pos - values[i]) / resolution
|
upp_factor = (end_pos - values[i]) / resolution
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
|
||||||
|
|
||||||
if not in_range:
|
if not in_range:
|
||||||
# Get first integer between the two bounds
|
# Get first integer between the two bounds
|
||||||
@@ -596,10 +599,15 @@ class DynamixelMotorsBus:
|
|||||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
|
||||||
|
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
"Auto-correct calibration of motor '%s' by shifting value by {abs(factor)} full turns, "
|
||||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
"from '%s' to '%s'.",
|
||||||
|
name,
|
||||||
|
out_of_range_str,
|
||||||
|
in_range_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||||
@@ -656,8 +664,8 @@ class DynamixelMotorsBus:
|
|||||||
motor_ids = [motor_ids]
|
motor_ids = [motor_ids]
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||||
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
group.addParam(idx)
|
group.addParam(idx)
|
||||||
|
|
||||||
@@ -674,7 +682,7 @@ class DynamixelMotorsBus:
|
|||||||
|
|
||||||
values = []
|
values = []
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
value = group.getData(idx, addr, bytes)
|
value = group.getData(idx, addr, byte)
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
if return_list:
|
if return_list:
|
||||||
@@ -709,13 +717,13 @@ class DynamixelMotorsBus:
|
|||||||
models.append(model)
|
models.append(model)
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
addr, byte = self.model_ctrl_table[model][data_name]
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
|
|
||||||
if data_name not in self.group_readers:
|
if data_name not in self.group_readers:
|
||||||
# create new group reader
|
# create new group reader
|
||||||
self.group_readers[group_key] = dxl.GroupSyncRead(
|
self.group_readers[group_key] = dxl.GroupSyncRead(
|
||||||
self.port_handler, self.packet_handler, addr, bytes
|
self.port_handler, self.packet_handler, addr, byte
|
||||||
)
|
)
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
self.group_readers[group_key].addParam(idx)
|
self.group_readers[group_key].addParam(idx)
|
||||||
@@ -733,7 +741,7 @@ class DynamixelMotorsBus:
|
|||||||
|
|
||||||
values = []
|
values = []
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
value = self.group_readers[group_key].getData(idx, addr, byte)
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
values = np.array(values)
|
values = np.array(values)
|
||||||
@@ -767,10 +775,10 @@ class DynamixelMotorsBus:
|
|||||||
values = [values]
|
values = [values]
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||||
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
|
||||||
for idx, value in zip(motor_ids, values, strict=True):
|
for idx, value in zip(motor_ids, values, strict=True):
|
||||||
data = convert_to_bytes(value, bytes, self.mock)
|
data = convert_to_bytes(value, byte, self.mock)
|
||||||
group.addParam(idx, data)
|
group.addParam(idx, data)
|
||||||
|
|
||||||
for _ in range(num_retry):
|
for _ in range(num_retry):
|
||||||
@@ -821,17 +829,17 @@ class DynamixelMotorsBus:
|
|||||||
values = values.tolist()
|
values = values.tolist()
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
addr, byte = self.model_ctrl_table[model][data_name]
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
|
|
||||||
init_group = data_name not in self.group_readers
|
init_group = data_name not in self.group_readers
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key] = dxl.GroupSyncWrite(
|
self.group_writers[group_key] = dxl.GroupSyncWrite(
|
||||||
self.port_handler, self.packet_handler, addr, bytes
|
self.port_handler, self.packet_handler, addr, byte
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, value in zip(motor_ids, values, strict=True):
|
for idx, value in zip(motor_ids, values, strict=True):
|
||||||
data = convert_to_bytes(value, bytes, self.mock)
|
data = convert_to_bytes(value, byte, self.mock)
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key].addParam(idx, data)
|
self.group_writers[group_key].addParam(idx, data)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
|
|||||||
return steps
|
return steps
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bytes(value, bytes, mock=False):
|
def convert_to_bytes(value, byte, mock=False):
|
||||||
if mock:
|
if mock:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@@ -156,16 +156,16 @@ def convert_to_bytes(value, bytes, mock=False):
|
|||||||
|
|
||||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||||
# already handles it for us.
|
# already handles it for us.
|
||||||
if bytes == 1:
|
if byte == 1:
|
||||||
data = [
|
data = [
|
||||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||||
]
|
]
|
||||||
elif bytes == 2:
|
elif byte == 2:
|
||||||
data = [
|
data = [
|
||||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||||
]
|
]
|
||||||
elif bytes == 4:
|
elif byte == 4:
|
||||||
data = [
|
data = [
|
||||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||||
@@ -175,7 +175,7 @@ def convert_to_bytes(value, bytes, mock=False):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||||
f"{bytes} is provided instead."
|
f"{byte} is provided instead."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -207,9 +207,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
|
|||||||
all_addr = []
|
all_addr = []
|
||||||
all_bytes = []
|
all_bytes = []
|
||||||
for model in motor_models:
|
for model in motor_models:
|
||||||
addr, bytes = model_ctrl_table[model][data_name]
|
addr, byte = model_ctrl_table[model][data_name]
|
||||||
all_addr.append(addr)
|
all_addr.append(addr)
|
||||||
all_bytes.append(bytes)
|
all_bytes.append(byte)
|
||||||
|
|
||||||
if len(set(all_addr)) != 1:
|
if len(set(all_addr)) != 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -557,6 +557,8 @@ class FeetechMotorsBus:
|
|||||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||||
low_factor = (start_pos - values[i]) / resolution
|
low_factor = (start_pos - values[i]) / resolution
|
||||||
upp_factor = (end_pos - values[i]) / resolution
|
upp_factor = (end_pos - values[i]) / resolution
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown calibration mode {calib_mode}")
|
||||||
|
|
||||||
if not in_range:
|
if not in_range:
|
||||||
# Get first integer between the two bounds
|
# Get first integer between the two bounds
|
||||||
@@ -577,10 +579,16 @@ class FeetechMotorsBus:
|
|||||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown calibration mode {calib_mode}")
|
||||||
|
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
"Auto-correct calibration of motor '%s' by shifting value by %s full turns, "
|
||||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
"from '%s' to '%s'.",
|
||||||
|
name,
|
||||||
|
abs(factor),
|
||||||
|
out_of_range_str,
|
||||||
|
in_range_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||||
@@ -674,8 +682,8 @@ class FeetechMotorsBus:
|
|||||||
motor_ids = [motor_ids]
|
motor_ids = [motor_ids]
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||||
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
group.addParam(idx)
|
group.addParam(idx)
|
||||||
|
|
||||||
@@ -692,7 +700,7 @@ class FeetechMotorsBus:
|
|||||||
|
|
||||||
values = []
|
values = []
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
value = group.getData(idx, addr, bytes)
|
value = group.getData(idx, addr, byte)
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
if return_list:
|
if return_list:
|
||||||
@@ -727,7 +735,7 @@ class FeetechMotorsBus:
|
|||||||
models.append(model)
|
models.append(model)
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
addr, byte = self.model_ctrl_table[model][data_name]
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
|
|
||||||
if data_name not in self.group_readers:
|
if data_name not in self.group_readers:
|
||||||
@@ -737,7 +745,7 @@ class FeetechMotorsBus:
|
|||||||
|
|
||||||
# create new group reader
|
# create new group reader
|
||||||
self.group_readers[group_key] = scs.GroupSyncRead(
|
self.group_readers[group_key] = scs.GroupSyncRead(
|
||||||
self.port_handler, self.packet_handler, addr, bytes
|
self.port_handler, self.packet_handler, addr, byte
|
||||||
)
|
)
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
self.group_readers[group_key].addParam(idx)
|
self.group_readers[group_key].addParam(idx)
|
||||||
@@ -755,7 +763,7 @@ class FeetechMotorsBus:
|
|||||||
|
|
||||||
values = []
|
values = []
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
value = self.group_readers[group_key].getData(idx, addr, byte)
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
values = np.array(values)
|
values = np.array(values)
|
||||||
@@ -792,10 +800,10 @@ class FeetechMotorsBus:
|
|||||||
values = [values]
|
values = [values]
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||||
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
|
||||||
for idx, value in zip(motor_ids, values, strict=True):
|
for idx, value in zip(motor_ids, values, strict=True):
|
||||||
data = convert_to_bytes(value, bytes, self.mock)
|
data = convert_to_bytes(value, byte, self.mock)
|
||||||
group.addParam(idx, data)
|
group.addParam(idx, data)
|
||||||
|
|
||||||
for _ in range(num_retry):
|
for _ in range(num_retry):
|
||||||
@@ -846,17 +854,17 @@ class FeetechMotorsBus:
|
|||||||
values = values.tolist()
|
values = values.tolist()
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
addr, byte = self.model_ctrl_table[model][data_name]
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
|
|
||||||
init_group = data_name not in self.group_readers
|
init_group = data_name not in self.group_readers
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key] = scs.GroupSyncWrite(
|
self.group_writers[group_key] = scs.GroupSyncWrite(
|
||||||
self.port_handler, self.packet_handler, addr, bytes
|
self.port_handler, self.packet_handler, addr, byte
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, value in zip(motor_ids, values, strict=True):
|
for idx, value in zip(motor_ids, values, strict=True):
|
||||||
data = convert_to_bytes(value, bytes, self.mock)
|
data = convert_to_bytes(value, byte, self.mock)
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key].addParam(idx, data)
|
self.group_writers[group_key].addParam(idx, data)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -95,6 +95,8 @@ def move_to_calibrate(
|
|||||||
while_move_hook=None,
|
while_move_hook=None,
|
||||||
):
|
):
|
||||||
initial_pos = arm.read("Present_Position", motor_name)
|
initial_pos = arm.read("Present_Position", motor_name)
|
||||||
|
p_present_pos = None
|
||||||
|
n_present_pos = None
|
||||||
|
|
||||||
if positive_first:
|
if positive_first:
|
||||||
p_present_pos = move_until_block(
|
p_present_pos = move_until_block(
|
||||||
@@ -196,7 +198,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
|
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
|
||||||
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
|
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
|
||||||
|
|
||||||
def in_between_move_hook():
|
def in_between_move_hook_elbow():
|
||||||
nonlocal arm, calib
|
nonlocal arm, calib
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
ef_pos = arm.read("Present_Position", "elbow_flex")
|
ef_pos = arm.read("Present_Position", "elbow_flex")
|
||||||
@@ -207,14 +209,14 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
|
|
||||||
print("Calibrate elbow_flex")
|
print("Calibrate elbow_flex")
|
||||||
calib["elbow_flex"] = move_to_calibrate(
|
calib["elbow_flex"] = move_to_calibrate(
|
||||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
|
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook_elbow
|
||||||
)
|
)
|
||||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||||
|
|
||||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
def in_between_move_hook():
|
def in_between_move_hook_shoulder():
|
||||||
nonlocal arm, calib
|
nonlocal arm, calib
|
||||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
|
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
|
||||||
|
|
||||||
@@ -224,7 +226,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
"shoulder_lift",
|
"shoulder_lift",
|
||||||
invert_drive_mode=True,
|
invert_drive_mode=True,
|
||||||
positive_first=False,
|
positive_first=False,
|
||||||
in_between_move_hook=in_between_move_hook,
|
in_between_move_hook=in_between_move_hook_shoulder,
|
||||||
)
|
)
|
||||||
# add an 30 steps as offset to align with body
|
# add an 30 steps as offset to align with body
|
||||||
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
|
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
|
||||||
|
|||||||
@@ -67,14 +67,14 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if calib_file.exists():
|
if calib_file.exists():
|
||||||
with open(calib_file) as f:
|
with open(calib_file, encoding="utf-8") as f:
|
||||||
calibration = json.load(f)
|
calibration = json.load(f)
|
||||||
print(f"[INFO] Loaded calibration from {calib_file}")
|
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||||
else:
|
else:
|
||||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||||
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||||
with open(calib_file, "w") as f:
|
with open(calib_file, "w", encoding="utf-8") as f:
|
||||||
json.dump(calibration, f)
|
json.dump(calibration, f)
|
||||||
try:
|
try:
|
||||||
motors_bus.set_calibration(calibration)
|
motors_bus.set_calibration(calibration)
|
||||||
|
|||||||
@@ -47,8 +47,10 @@ def ensure_safe_goal_position(
|
|||||||
if not torch.allclose(goal_pos, safe_goal_pos):
|
if not torch.allclose(goal_pos, safe_goal_pos):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||||
f" requested relative goal position target: {diff}\n"
|
" requested relative goal position target: %s\n"
|
||||||
f" clamped relative goal position target: {safe_diff}"
|
" clamped relative goal position target: %s",
|
||||||
|
diff,
|
||||||
|
safe_diff,
|
||||||
)
|
)
|
||||||
|
|
||||||
return safe_goal_pos
|
return safe_goal_pos
|
||||||
@@ -245,6 +247,8 @@ class ManipulatorRobot:
|
|||||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Robot type {self.robot_type} is not supported")
|
||||||
|
|
||||||
# We assume that at connection time, arms are in a rest position, and torque can
|
# We assume that at connection time, arms are in a rest position, and torque can
|
||||||
# be safely disabled to run calibration and/or set robot preset configurations.
|
# be safely disabled to run calibration and/or set robot preset configurations.
|
||||||
@@ -302,7 +306,7 @@ class ManipulatorRobot:
|
|||||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||||
|
|
||||||
if arm_calib_path.exists():
|
if arm_calib_path.exists():
|
||||||
with open(arm_calib_path) as f:
|
with open(arm_calib_path, encoding="utf-8") as f:
|
||||||
calibration = json.load(f)
|
calibration = json.load(f)
|
||||||
else:
|
else:
|
||||||
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
||||||
@@ -322,7 +326,7 @@ class ManipulatorRobot:
|
|||||||
|
|
||||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(arm_calib_path, "w") as f:
|
with open(arm_calib_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(calibration, f)
|
json.dump(calibration, f)
|
||||||
|
|
||||||
return calibration
|
return calibration
|
||||||
|
|||||||
@@ -262,14 +262,14 @@ class MobileManipulator:
|
|||||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||||
|
|
||||||
if arm_calib_path.exists():
|
if arm_calib_path.exists():
|
||||||
with open(arm_calib_path) as f:
|
with open(arm_calib_path, encoding="utf-8") as f:
|
||||||
calibration = json.load(f)
|
calibration = json.load(f)
|
||||||
else:
|
else:
|
||||||
print(f"Missing calibration file '{arm_calib_path}'")
|
print(f"Missing calibration file '{arm_calib_path}'")
|
||||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(arm_calib_path, "w") as f:
|
with open(arm_calib_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(calibration, f)
|
json.dump(calibration, f)
|
||||||
|
|
||||||
return calibration
|
return calibration
|
||||||
@@ -372,6 +372,7 @@ class MobileManipulator:
|
|||||||
|
|
||||||
present_speed = self.last_present_speed
|
present_speed = self.last_present_speed
|
||||||
|
|
||||||
|
# TODO(Steven): [WARN] Plenty of general exceptions
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[DEBUG] Error decoding video message: {e}")
|
print(f"[DEBUG] Error decoding video message: {e}")
|
||||||
# If decode fails, fall back to old data
|
# If decode fails, fall back to old data
|
||||||
|
|||||||
@@ -68,9 +68,9 @@ class TimeBenchmark(ContextDecorator):
|
|||||||
Block took approximately 10.00 milliseconds
|
Block took approximately 10.00 milliseconds
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, print=False):
|
def __init__(self, print_time=False):
|
||||||
self.local = threading.local()
|
self.local = threading.local()
|
||||||
self.print_time = print
|
self.print_time = print_time
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.local.start_time = time.perf_counter()
|
self.local.start_time = time.perf_counter()
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
|||||||
else:
|
else:
|
||||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||||
package_exists = False
|
package_exists = False
|
||||||
logging.debug(f"Detected {pkg_name} version: {package_version}")
|
logging.debug("Detected %s version: %s", {pkg_name}, package_version)
|
||||||
if return_version:
|
if return_version:
|
||||||
return package_exists, package_version
|
return package_exists, package_version
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ class AverageMeter:
|
|||||||
def __init__(self, name: str, fmt: str = ":f"):
|
def __init__(self, name: str, fmt: str = ":f"):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.fmt = fmt
|
self.fmt = fmt
|
||||||
|
self.val = 0.0
|
||||||
|
self.avg = 0.0
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
|||||||
case _:
|
case _:
|
||||||
device = torch.device(try_device)
|
device = torch.device(try_device)
|
||||||
if log:
|
if log:
|
||||||
logging.warning(f"Using custom {try_device} device.")
|
logging.warning("Using custom %s device.", try_device)
|
||||||
|
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class WandBLogger:
|
|||||||
resume="must" if cfg.resume else None,
|
resume="must" if cfg.resume else None,
|
||||||
)
|
)
|
||||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
logging.info("Track this run --> %s", colored(wandb.run.get_url(), "yellow", attrs=["bold"]))
|
||||||
self._wandb = wandb
|
self._wandb = wandb
|
||||||
|
|
||||||
def log_policy(self, checkpoint_dir: Path):
|
def log_policy(self, checkpoint_dir: Path):
|
||||||
@@ -108,7 +108,7 @@ class WandBLogger:
|
|||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if not isinstance(v, (int, float, str)):
|
if not isinstance(v, (int, float, str)):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
'WandB logging of key "%s" was ignored as its type is not handled by this wrapper.', k
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||||
|
|||||||
@@ -64,13 +64,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
self.pretrained_path = None
|
self.pretrained_path = None
|
||||||
if not self.device or not is_torch_device_available(self.device):
|
if not self.device or not is_torch_device_available(self.device):
|
||||||
auto_device = auto_select_torch_device()
|
auto_device = auto_select_torch_device()
|
||||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
logging.warning("Device '%s' is not available. Switching to '%s'.", self.device, auto_device)
|
||||||
self.device = auto_device.type
|
self.device = auto_device.type
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
# Automatically deactivate AMP if necessary
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
if self.use_amp and not is_amp_available(self.device):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
"Automatic Mixed Precision (amp) is not available on device '%s'. Deactivating AMP.",
|
||||||
|
self.device,
|
||||||
)
|
)
|
||||||
self.use_amp = False
|
self.use_amp = False
|
||||||
|
|
||||||
@@ -78,15 +79,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return self.get_choice_name(self.__class__)
|
return self.get_choice_name(self.__class__)
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def observation_delta_indices(self) -> list | None:
|
def observation_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def action_delta_indices(self) -> list | None:
|
def action_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def reward_delta_indices(self) -> list | None:
|
def reward_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -128,7 +132,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path) -> None:
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
with open(save_directory / CONFIG_NAME, "w", encoding="utf-8") as f, draccus.config_type("json"):
|
||||||
draccus.dump(self, f, indent=4)
|
draccus.dump(self, f, indent=4)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -123,7 +123,10 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
return draccus.encode(self)
|
return draccus.encode(self)
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path) -> None:
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
with (
|
||||||
|
open(save_directory / TRAIN_CONFIG_NAME, "w", encoding="utf-8") as f,
|
||||||
|
draccus.config_type("json"),
|
||||||
|
):
|
||||||
draccus.dump(self, f, indent=4)
|
draccus.dump(self, f, indent=4)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
|||||||
print("Scanning all baudrates and motor indices")
|
print("Scanning all baudrates and motor indices")
|
||||||
all_baudrates = set(series_baudrate_table.values())
|
all_baudrates = set(series_baudrate_table.values())
|
||||||
motor_index = -1 # Set the motor index to an out-of-range value.
|
motor_index = -1 # Set the motor index to an out-of-range value.
|
||||||
|
baudrate = None
|
||||||
|
|
||||||
for baudrate in all_baudrates:
|
for baudrate in all_baudrates:
|
||||||
motor_bus.set_bus_baudrate(baudrate)
|
motor_bus.set_bus_baudrate(baudrate)
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ This might require a sudo permission to allow your terminal to monitor keyboard
|
|||||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO(Steven): This script should be updated to use the new robot API and the new dataset API.
|
||||||
import argparse
|
import argparse
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"
|
|||||||
|
|
||||||
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
||||||
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
||||||
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
cuda_version = torch.version.cuda if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||||
|
|||||||
@@ -259,6 +259,10 @@ def eval_policy(
|
|||||||
threads = [] # for video saving threads
|
threads = [] # for video saving threads
|
||||||
n_episodes_rendered = 0 # for saving the correct number of videos
|
n_episodes_rendered = 0 # for saving the correct number of videos
|
||||||
|
|
||||||
|
video_paths: list[str] = [] # max_episodes_rendered > 0:
|
||||||
|
ep_frames: list[np.ndarray] = [] # max_episodes_rendered > 0
|
||||||
|
episode_data: dict | None = None # return_episode_data == True
|
||||||
|
|
||||||
# Callback for visualization.
|
# Callback for visualization.
|
||||||
def render_frame(env: gym.vector.VectorEnv):
|
def render_frame(env: gym.vector.VectorEnv):
|
||||||
# noqa: B023
|
# noqa: B023
|
||||||
@@ -271,19 +275,11 @@ def eval_policy(
|
|||||||
# Here we must render all frames and discard any we don't need.
|
# Here we must render all frames and discard any we don't need.
|
||||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||||
|
|
||||||
if max_episodes_rendered > 0:
|
|
||||||
video_paths: list[str] = []
|
|
||||||
|
|
||||||
if return_episode_data:
|
|
||||||
episode_data: dict | None = None
|
|
||||||
|
|
||||||
# we dont want progress bar when we use slurm, since it clutters the logs
|
# we dont want progress bar when we use slurm, since it clutters the logs
|
||||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
||||||
for batch_ix in progbar:
|
for batch_ix in progbar:
|
||||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||||
# step.
|
# step.
|
||||||
if max_episodes_rendered > 0:
|
|
||||||
ep_frames: list[np.ndarray] = []
|
|
||||||
|
|
||||||
if start_seed is None:
|
if start_seed is None:
|
||||||
seeds = None
|
seeds = None
|
||||||
@@ -320,13 +316,19 @@ def eval_policy(
|
|||||||
else:
|
else:
|
||||||
all_seeds.append(None)
|
all_seeds.append(None)
|
||||||
|
|
||||||
# FIXME: episode_data is either None or it doesn't exist
|
|
||||||
if return_episode_data:
|
if return_episode_data:
|
||||||
|
if episode_data is None:
|
||||||
|
start_data_index = 0
|
||||||
|
elif isinstance(episode_data, dict):
|
||||||
|
start_data_index = episode_data["index"][-1].item() + 1
|
||||||
|
else:
|
||||||
|
start_data_index = 0
|
||||||
|
|
||||||
this_episode_data = _compile_episode_data(
|
this_episode_data = _compile_episode_data(
|
||||||
rollout_data,
|
rollout_data,
|
||||||
done_indices,
|
done_indices,
|
||||||
start_episode_index=batch_ix * env.num_envs,
|
start_episode_index=batch_ix * env.num_envs,
|
||||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
start_data_index=start_data_index,
|
||||||
fps=env.unwrapped.metadata["render_fps"],
|
fps=env.unwrapped.metadata["render_fps"],
|
||||||
)
|
)
|
||||||
if episode_data is None:
|
if episode_data is None:
|
||||||
@@ -453,6 +455,7 @@ def _compile_episode_data(
|
|||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): [WARN] Redefining built-in 'eval'
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def eval_main(cfg: EvalPipelineConfig):
|
def eval_main(cfg: EvalPipelineConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
@@ -489,7 +492,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
print(info["aggregated"])
|
print(info["aggregated"])
|
||||||
|
|
||||||
# Save info
|
# Save info
|
||||||
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
with open(Path(cfg.output_dir) / "eval_info.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(info, f, indent=2)
|
json.dump(info, f, indent=2)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ import torch
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
# TODO(Steven): #711 Broke this
|
||||||
from lerobot.common.datasets.compute_stats import compute_stats
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||||
@@ -89,7 +90,7 @@ def save_meta_data(
|
|||||||
|
|
||||||
# save info
|
# save info
|
||||||
info_path = meta_data_dir / "info.json"
|
info_path = meta_data_dir / "info.json"
|
||||||
with open(str(info_path), "w") as f:
|
with open(str(info_path), "w", encoding="utf-8") as f:
|
||||||
json.dump(info, f, indent=4)
|
json.dump(info, f, indent=4)
|
||||||
|
|
||||||
# save stats
|
# save stats
|
||||||
@@ -120,11 +121,11 @@ def push_dataset_card_to_hub(
|
|||||||
repo_id: str,
|
repo_id: str,
|
||||||
revision: str | None,
|
revision: str | None,
|
||||||
tags: list | None = None,
|
tags: list | None = None,
|
||||||
license: str = "apache-2.0",
|
dataset_license: str = "apache-2.0",
|
||||||
**card_kwargs,
|
**card_kwargs,
|
||||||
):
|
):
|
||||||
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
||||||
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
|
card = create_lerobot_dataset_card(tags=tags, license=dataset_license, **card_kwargs)
|
||||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
||||||
|
|
||||||
|
|
||||||
@@ -213,6 +214,7 @@ def push_dataset_to_hub(
|
|||||||
encoding,
|
encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO(Steven): This doesn't seem to exist, maybe it was removed/changed recently?
|
||||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
hf_dataset=hf_dataset,
|
hf_dataset=hf_dataset,
|
||||||
|
|||||||
@@ -155,12 +155,14 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||||
if cfg.env is not None:
|
if cfg.env is not None:
|
||||||
logging.info(f"{cfg.env.task=}")
|
logging.info("cfg.env.task=%s", cfg.env.task)
|
||||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
logging.info("cfg.steps=%s (%s)", cfg.steps, format_big_number(cfg.steps))
|
||||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
logging.info("dataset.num_frames=%s (%s)", dataset.num_frames, format_big_number(dataset.num_frames))
|
||||||
logging.info(f"{dataset.num_episodes=}")
|
logging.info("dataset.num_episodes=%s", dataset.num_episodes)
|
||||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
logging.info(
|
||||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
"num_learnable_params=%s (%s)", num_learnable_params, format_big_number(num_learnable_params)
|
||||||
|
)
|
||||||
|
logging.info("num_total_params=%s (%s)", num_total_params, format_big_number(num_total_params))
|
||||||
|
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||||
@@ -238,7 +240,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
train_tracker.reset_averages()
|
train_tracker.reset_averages()
|
||||||
|
|
||||||
if cfg.save_checkpoint and is_saving_step:
|
if cfg.save_checkpoint and is_saving_step:
|
||||||
logging.info(f"Checkpoint policy after step {step}")
|
logging.info("Checkpoint policy after step %s", step)
|
||||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||||
update_last_checkpoint(checkpoint_dir)
|
update_last_checkpoint(checkpoint_dir)
|
||||||
@@ -247,7 +249,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
if cfg.env and is_eval_step:
|
if cfg.env and is_eval_step:
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info("Eval policy at step %s", step)
|
||||||
with (
|
with (
|
||||||
torch.no_grad(),
|
torch.no_grad(),
|
||||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ def run_server(
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
dataset_version = (
|
dataset_version = (
|
||||||
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
str(dataset.meta.version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||||
)
|
)
|
||||||
match = re.search(r"v(\d+)\.", dataset_version)
|
match = re.search(r"v(\d+)\.", dataset_version)
|
||||||
if match:
|
if match:
|
||||||
@@ -358,7 +358,7 @@ def visualize_dataset_html(
|
|||||||
if force_override:
|
if force_override:
|
||||||
shutil.rmtree(output_dir)
|
shutil.rmtree(output_dir)
|
||||||
else:
|
else:
|
||||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
logging.info("Output directory already exists. Loading from it: '%s'", {output_dir})
|
||||||
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|||||||
68
tests/fixtures/dataset_factories.py
vendored
68
tests/fixtures/dataset_factories.py
vendored
@@ -52,16 +52,16 @@ def get_task_index(task_dicts: dict, task: str) -> int:
|
|||||||
return task_to_task_index[task]
|
return task_to_task_index[task]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="img_tensor_factory", scope="session")
|
||||||
def img_tensor_factory():
|
def fixture_img_tensor_factory():
|
||||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||||
return torch.rand((channels, height, width), dtype=dtype)
|
return torch.rand((channels, height, width), dtype=dtype)
|
||||||
|
|
||||||
return _create_img_tensor
|
return _create_img_tensor
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="img_array_factory", scope="session")
|
||||||
def img_array_factory():
|
def fixture_img_array_factory():
|
||||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||||
if np.issubdtype(dtype, np.unsignedinteger):
|
if np.issubdtype(dtype, np.unsignedinteger):
|
||||||
# Int array in [0, 255] range
|
# Int array in [0, 255] range
|
||||||
@@ -76,8 +76,8 @@ def img_array_factory():
|
|||||||
return _create_img_array
|
return _create_img_array
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="img_factory", scope="session")
|
||||||
def img_factory(img_array_factory):
|
def fixture_img_factory(img_array_factory):
|
||||||
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
||||||
img_array = img_array_factory(height=height, width=width)
|
img_array = img_array_factory(height=height, width=width)
|
||||||
return PIL.Image.fromarray(img_array)
|
return PIL.Image.fromarray(img_array)
|
||||||
@@ -85,13 +85,17 @@ def img_factory(img_array_factory):
|
|||||||
return _create_img
|
return _create_img
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="features_factory", scope="session")
|
||||||
def features_factory():
|
def fixture_features_factory():
|
||||||
def _create_features(
|
def _create_features(
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict | None = None,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict | None = None,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if motor_features is None:
|
||||||
|
motor_features = DUMMY_MOTOR_FEATURES
|
||||||
|
if camera_features is None:
|
||||||
|
camera_features = DUMMY_CAMERA_FEATURES
|
||||||
if use_videos:
|
if use_videos:
|
||||||
camera_ft = {
|
camera_ft = {
|
||||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||||
@@ -107,8 +111,8 @@ def features_factory():
|
|||||||
return _create_features
|
return _create_features
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="info_factory", scope="session")
|
||||||
def info_factory(features_factory):
|
def fixture_info_factory(features_factory):
|
||||||
def _create_info(
|
def _create_info(
|
||||||
codebase_version: str = CODEBASE_VERSION,
|
codebase_version: str = CODEBASE_VERSION,
|
||||||
fps: int = DEFAULT_FPS,
|
fps: int = DEFAULT_FPS,
|
||||||
@@ -121,10 +125,14 @@ def info_factory(features_factory):
|
|||||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||||
data_path: str = DEFAULT_PARQUET_PATH,
|
data_path: str = DEFAULT_PARQUET_PATH,
|
||||||
video_path: str = DEFAULT_VIDEO_PATH,
|
video_path: str = DEFAULT_VIDEO_PATH,
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict | None = None,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict | None = None,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if motor_features is None:
|
||||||
|
motor_features = DUMMY_MOTOR_FEATURES
|
||||||
|
if camera_features is None:
|
||||||
|
camera_features = DUMMY_CAMERA_FEATURES
|
||||||
features = features_factory(motor_features, camera_features, use_videos)
|
features = features_factory(motor_features, camera_features, use_videos)
|
||||||
return {
|
return {
|
||||||
"codebase_version": codebase_version,
|
"codebase_version": codebase_version,
|
||||||
@@ -145,8 +153,8 @@ def info_factory(features_factory):
|
|||||||
return _create_info
|
return _create_info
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="stats_factory", scope="session")
|
||||||
def stats_factory():
|
def fixture_stats_factory():
|
||||||
def _create_stats(
|
def _create_stats(
|
||||||
features: dict[str] | None = None,
|
features: dict[str] | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -175,8 +183,8 @@ def stats_factory():
|
|||||||
return _create_stats
|
return _create_stats
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="episodes_stats_factory", scope="session")
|
||||||
def episodes_stats_factory(stats_factory):
|
def fixture_episodes_stats_factory(stats_factory):
|
||||||
def _create_episodes_stats(
|
def _create_episodes_stats(
|
||||||
features: dict[str],
|
features: dict[str],
|
||||||
total_episodes: int = 3,
|
total_episodes: int = 3,
|
||||||
@@ -192,8 +200,8 @@ def episodes_stats_factory(stats_factory):
|
|||||||
return _create_episodes_stats
|
return _create_episodes_stats
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="tasks_factory", scope="session")
|
||||||
def tasks_factory():
|
def fixture_tasks_factory():
|
||||||
def _create_tasks(total_tasks: int = 3) -> int:
|
def _create_tasks(total_tasks: int = 3) -> int:
|
||||||
tasks = {}
|
tasks = {}
|
||||||
for task_index in range(total_tasks):
|
for task_index in range(total_tasks):
|
||||||
@@ -204,8 +212,8 @@ def tasks_factory():
|
|||||||
return _create_tasks
|
return _create_tasks
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="episodes_factory", scope="session")
|
||||||
def episodes_factory(tasks_factory):
|
def fixture_episodes_factory(tasks_factory):
|
||||||
def _create_episodes(
|
def _create_episodes(
|
||||||
total_episodes: int = 3,
|
total_episodes: int = 3,
|
||||||
total_frames: int = 400,
|
total_frames: int = 400,
|
||||||
@@ -252,8 +260,8 @@ def episodes_factory(tasks_factory):
|
|||||||
return _create_episodes
|
return _create_episodes
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="hf_dataset_factory", scope="session")
|
||||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
def fixture_hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||||
def _create_hf_dataset(
|
def _create_hf_dataset(
|
||||||
features: dict | None = None,
|
features: dict | None = None,
|
||||||
tasks: list[dict] | None = None,
|
tasks: list[dict] | None = None,
|
||||||
@@ -310,8 +318,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
|||||||
return _create_hf_dataset
|
return _create_hf_dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="lerobot_dataset_metadata_factory", scope="session")
|
||||||
def lerobot_dataset_metadata_factory(
|
def fixture_lerobot_dataset_metadata_factory(
|
||||||
info_factory,
|
info_factory,
|
||||||
stats_factory,
|
stats_factory,
|
||||||
episodes_stats_factory,
|
episodes_stats_factory,
|
||||||
@@ -364,8 +372,8 @@ def lerobot_dataset_metadata_factory(
|
|||||||
return _create_lerobot_dataset_metadata
|
return _create_lerobot_dataset_metadata
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="lerobot_dataset_factory", scope="session")
|
||||||
def lerobot_dataset_factory(
|
def fixture_lerobot_dataset_factory(
|
||||||
info_factory,
|
info_factory,
|
||||||
stats_factory,
|
stats_factory,
|
||||||
episodes_stats_factory,
|
episodes_stats_factory,
|
||||||
@@ -443,6 +451,6 @@ def lerobot_dataset_factory(
|
|||||||
return _create_lerobot_dataset
|
return _create_lerobot_dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="empty_lerobot_dataset_factory", scope="session")
|
||||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
def fixture_empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
||||||
|
|||||||
34
tests/fixtures/files.py
vendored
34
tests/fixtures/files.py
vendored
@@ -31,12 +31,12 @@ from lerobot.common.datasets.utils import (
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def info_path(info_factory):
|
def info_path(info_factory):
|
||||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
def _create_info_json_file(input_dir: Path, info: dict | None = None) -> Path:
|
||||||
if not info:
|
if not info:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
fpath = dir / INFO_PATH
|
fpath = input_dir / INFO_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(fpath, "w") as f:
|
with open(fpath, "w", encoding="utf-8") as f:
|
||||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
json.dump(info, f, indent=4, ensure_ascii=False)
|
||||||
return fpath
|
return fpath
|
||||||
|
|
||||||
@@ -45,12 +45,12 @@ def info_path(info_factory):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def stats_path(stats_factory):
|
def stats_path(stats_factory):
|
||||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
def _create_stats_json_file(input_dir: Path, stats: dict | None = None) -> Path:
|
||||||
if not stats:
|
if not stats:
|
||||||
stats = stats_factory()
|
stats = stats_factory()
|
||||||
fpath = dir / STATS_PATH
|
fpath = input_dir / STATS_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(fpath, "w") as f:
|
with open(fpath, "w", encoding="utf-8") as f:
|
||||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
json.dump(stats, f, indent=4, ensure_ascii=False)
|
||||||
return fpath
|
return fpath
|
||||||
|
|
||||||
@@ -59,10 +59,10 @@ def stats_path(stats_factory):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def episodes_stats_path(episodes_stats_factory):
|
def episodes_stats_path(episodes_stats_factory):
|
||||||
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
def _create_episodes_stats_jsonl_file(input_dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
||||||
if not episodes_stats:
|
if not episodes_stats:
|
||||||
episodes_stats = episodes_stats_factory()
|
episodes_stats = episodes_stats_factory()
|
||||||
fpath = dir / EPISODES_STATS_PATH
|
fpath = input_dir / EPISODES_STATS_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
with jsonlines.open(fpath, "w") as writer:
|
||||||
writer.write_all(episodes_stats.values())
|
writer.write_all(episodes_stats.values())
|
||||||
@@ -73,10 +73,10 @@ def episodes_stats_path(episodes_stats_factory):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def tasks_path(tasks_factory):
|
def tasks_path(tasks_factory):
|
||||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
def _create_tasks_jsonl_file(input_dir: Path, tasks: list | None = None) -> Path:
|
||||||
if not tasks:
|
if not tasks:
|
||||||
tasks = tasks_factory()
|
tasks = tasks_factory()
|
||||||
fpath = dir / TASKS_PATH
|
fpath = input_dir / TASKS_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
with jsonlines.open(fpath, "w") as writer:
|
||||||
writer.write_all(tasks.values())
|
writer.write_all(tasks.values())
|
||||||
@@ -87,10 +87,10 @@ def tasks_path(tasks_factory):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def episode_path(episodes_factory):
|
def episode_path(episodes_factory):
|
||||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
def _create_episodes_jsonl_file(input_dir: Path, episodes: list | None = None) -> Path:
|
||||||
if not episodes:
|
if not episodes:
|
||||||
episodes = episodes_factory()
|
episodes = episodes_factory()
|
||||||
fpath = dir / EPISODES_PATH
|
fpath = input_dir / EPISODES_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
with jsonlines.open(fpath, "w") as writer:
|
||||||
writer.write_all(episodes.values())
|
writer.write_all(episodes.values())
|
||||||
@@ -102,7 +102,7 @@ def episode_path(episodes_factory):
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||||
def _create_single_episode_parquet(
|
def _create_single_episode_parquet(
|
||||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
input_dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||||
) -> Path:
|
) -> Path:
|
||||||
if not info:
|
if not info:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
@@ -112,7 +112,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
|||||||
data_path = info["data_path"]
|
data_path = info["data_path"]
|
||||||
chunks_size = info["chunks_size"]
|
chunks_size = info["chunks_size"]
|
||||||
ep_chunk = ep_idx // chunks_size
|
ep_chunk = ep_idx // chunks_size
|
||||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
table = hf_dataset.data.table
|
table = hf_dataset.data.table
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||||
@@ -125,7 +125,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||||
def _create_multi_episode_parquet(
|
def _create_multi_episode_parquet(
|
||||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
input_dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||||
) -> Path:
|
) -> Path:
|
||||||
if not info:
|
if not info:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
@@ -137,11 +137,11 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
|||||||
total_episodes = info["total_episodes"]
|
total_episodes = info["total_episodes"]
|
||||||
for ep_idx in range(total_episodes):
|
for ep_idx in range(total_episodes):
|
||||||
ep_chunk = ep_idx // chunks_size
|
ep_chunk = ep_idx // chunks_size
|
||||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
table = hf_dataset.data.table
|
table = hf_dataset.data.table
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||||
pq.write_table(ep_table, fpath)
|
pq.write_table(ep_table, fpath)
|
||||||
return dir / "data"
|
return input_dir / "data"
|
||||||
|
|
||||||
return _create_multi_episode_parquet
|
return _create_multi_episode_parquet
|
||||||
|
|||||||
6
tests/fixtures/hub.py
vendored
6
tests/fixtures/hub.py
vendored
@@ -81,12 +81,12 @@ def mock_snapshot_download_factory(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _mock_snapshot_download(
|
def _mock_snapshot_download(
|
||||||
repo_id: str,
|
_repo_id: str,
|
||||||
|
*_args,
|
||||||
local_dir: str | Path | None = None,
|
local_dir: str | Path | None = None,
|
||||||
allow_patterns: str | list[str] | None = None,
|
allow_patterns: str | list[str] | None = None,
|
||||||
ignore_patterns: str | list[str] | None = None,
|
ignore_patterns: str | list[str] | None = None,
|
||||||
*args,
|
**_kwargs,
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
if not local_dir:
|
if not local_dir:
|
||||||
local_dir = LEROBOT_TEST_DIR
|
local_dir = LEROBOT_TEST_DIR
|
||||||
|
|||||||
12
tests/fixtures/optimizers.py
vendored
12
tests/fixtures/optimizers.py
vendored
@@ -18,13 +18,13 @@ from lerobot.common.optim.optimizers import AdamConfig
|
|||||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="model_params")
|
||||||
def model_params():
|
def fixture_model_params():
|
||||||
return [torch.nn.Parameter(torch.randn(10, 10))]
|
return [torch.nn.Parameter(torch.randn(10, 10))]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="optimizer")
|
||||||
def optimizer(model_params):
|
def fixture_optimizer(model_params):
|
||||||
optimizer = AdamConfig().build(model_params)
|
optimizer = AdamConfig().build(model_params)
|
||||||
# Dummy step to populate state
|
# Dummy step to populate state
|
||||||
loss = sum(param.sum() for param in model_params)
|
loss = sum(param.sum() for param in model_params)
|
||||||
@@ -33,7 +33,7 @@ def optimizer(model_params):
|
|||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="scheduler")
|
||||||
def scheduler(optimizer):
|
def fixture_scheduler(optimizer):
|
||||||
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
|
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
|
||||||
return config.build(optimizer, num_training_steps=100)
|
return config.build(optimizer, num_training_steps=100)
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ from lerobot.common.datasets.utils import create_lerobot_dataset_card
|
|||||||
def test_default_parameters():
|
def test_default_parameters():
|
||||||
card = create_lerobot_dataset_card()
|
card = create_lerobot_dataset_card()
|
||||||
assert isinstance(card, DatasetCard)
|
assert isinstance(card, DatasetCard)
|
||||||
|
# TODO(Steven): Base class CardDate should have 'tags' as a member if we want RepoCard to hold a reference to this abstraction
|
||||||
|
# card.data gives a CardDate type, implementations of this class do have 'tags' but the base class doesn't
|
||||||
assert card.data.tags == ["LeRobot"]
|
assert card.data.tags == ["LeRobot"]
|
||||||
assert card.data.task_categories == ["robotics"]
|
assert card.data.task_categories == ["robotics"]
|
||||||
assert card.data.configs == [
|
assert card.data.configs == [
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ def rotate(color_image, rotation):
|
|||||||
|
|
||||||
|
|
||||||
class VideoCapture:
|
class VideoCapture:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *_args, **_kwargs):
|
||||||
self._mock_dict = {
|
self._mock_dict = {
|
||||||
CAP_PROP_FPS: 30,
|
CAP_PROP_FPS: 30,
|
||||||
CAP_PROP_FRAME_WIDTH: 640,
|
CAP_PROP_FRAME_WIDTH: 640,
|
||||||
|
|||||||
@@ -24,10 +24,9 @@ DEFAULT_BAUDRATE = 9_600
|
|||||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bytes(value, bytes):
|
def convert_to_bytes(value, _byte):
|
||||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||||
# `convert_bytes_to_value`
|
# `convert_bytes_to_value`
|
||||||
del bytes # unused
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@@ -74,7 +73,7 @@ class PacketHandler:
|
|||||||
|
|
||||||
|
|
||||||
class GroupSyncRead:
|
class GroupSyncRead:
|
||||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
def __init__(self, _port_handler, packet_handler, _address, _byte):
|
||||||
self.packet_handler = packet_handler
|
self.packet_handler = packet_handler
|
||||||
|
|
||||||
def addParam(self, motor_index): # noqa: N802
|
def addParam(self, motor_index): # noqa: N802
|
||||||
@@ -85,12 +84,12 @@ class GroupSyncRead:
|
|||||||
def txRxPacket(self): # noqa: N802
|
def txRxPacket(self): # noqa: N802
|
||||||
return COMM_SUCCESS
|
return COMM_SUCCESS
|
||||||
|
|
||||||
def getData(self, index, address, bytes): # noqa: N802
|
def getData(self, index, address, _byte): # noqa: N802
|
||||||
return self.packet_handler.data[index][address]
|
return self.packet_handler.data[index][address]
|
||||||
|
|
||||||
|
|
||||||
class GroupSyncWrite:
|
class GroupSyncWrite:
|
||||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
def __init__(self, _port_handler, packet_handler, address, _byte):
|
||||||
self.packet_handler = packet_handler
|
self.packet_handler = packet_handler
|
||||||
self.address = address
|
self.address = address
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,13 @@ class format(enum.Enum): # noqa: N801
|
|||||||
|
|
||||||
|
|
||||||
class config: # noqa: N801
|
class config: # noqa: N801
|
||||||
|
device_enabled = None
|
||||||
|
stream_type = None
|
||||||
|
width = None
|
||||||
|
height = None
|
||||||
|
color_format = None
|
||||||
|
fps = None
|
||||||
|
|
||||||
def enable_device(self, device_id: str):
|
def enable_device(self, device_id: str):
|
||||||
self.device_enabled = device_id
|
self.device_enabled = device_id
|
||||||
|
|
||||||
@@ -125,8 +132,7 @@ class RSDevice:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_info(self, camera_info) -> str:
|
def get_info(self, _camera_info) -> str:
|
||||||
del camera_info # unused
|
|
||||||
# return fake serial number
|
# return fake serial number
|
||||||
return "123456789"
|
return "123456789"
|
||||||
|
|
||||||
@@ -145,4 +151,3 @@ class camera_info: # noqa: N801
|
|||||||
|
|
||||||
def __init__(self, serial_number):
|
def __init__(self, serial_number):
|
||||||
del serial_number
|
del serial_number
|
||||||
pass
|
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ DEFAULT_BAUDRATE = 1_000_000
|
|||||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bytes(value, bytes):
|
def convert_to_bytes(value, byte):
|
||||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||||
# `convert_bytes_to_value`
|
# `convert_bytes_to_value`
|
||||||
del bytes # unused
|
del byte # unused
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ class PacketHandler:
|
|||||||
|
|
||||||
|
|
||||||
class GroupSyncRead:
|
class GroupSyncRead:
|
||||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
def __init__(self, _port_handler, packet_handler, _address, _byte):
|
||||||
self.packet_handler = packet_handler
|
self.packet_handler = packet_handler
|
||||||
|
|
||||||
def addParam(self, motor_index): # noqa: N802
|
def addParam(self, motor_index): # noqa: N802
|
||||||
@@ -96,12 +96,12 @@ class GroupSyncRead:
|
|||||||
def txRxPacket(self): # noqa: N802
|
def txRxPacket(self): # noqa: N802
|
||||||
return COMM_SUCCESS
|
return COMM_SUCCESS
|
||||||
|
|
||||||
def getData(self, index, address, bytes): # noqa: N802
|
def getData(self, index, address, _byte): # noqa: N802
|
||||||
return self.packet_handler.data[index][address]
|
return self.packet_handler.data[index][address]
|
||||||
|
|
||||||
|
|
||||||
class GroupSyncWrite:
|
class GroupSyncWrite:
|
||||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
def __init__(self, _port_handler, packet_handler, address, _byte):
|
||||||
self.packet_handler = packet_handler
|
self.packet_handler = packet_handler
|
||||||
self.address = address
|
self.address = address
|
||||||
|
|
||||||
|
|||||||
@@ -81,11 +81,11 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
for dataset in [
|
for available_dataset in [
|
||||||
"lerobot/pusht",
|
"lerobot/pusht",
|
||||||
"lerobot/aloha_sim_insertion_human",
|
"lerobot/aloha_sim_insertion_human",
|
||||||
"lerobot/xarm_lift_medium",
|
"lerobot/xarm_lift_medium",
|
||||||
"lerobot/nyu_franka_play_dataset",
|
"lerobot/nyu_franka_play_dataset",
|
||||||
"lerobot/cmu_stretch",
|
"lerobot/cmu_stretch",
|
||||||
]:
|
]:
|
||||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
|
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=available_dataset)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
|
|||||||
}
|
}
|
||||||
|
|
||||||
frames = {"original_frame": original_frame}
|
frames = {"original_frame": original_frame}
|
||||||
for tf_type, tf_name, min_max_values in transforms.items():
|
for tf_type, tf_name, min_max_values in transforms:
|
||||||
for min_max in min_max_values:
|
for min_max in min_max_values:
|
||||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
||||||
tf = make_transform_from_config(tf_cfg)
|
tf = make_transform_from_config(tf_cfg)
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ def test_camera(request, camera_type, mock):
|
|||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
manual_rot_img: np.ndarray = None
|
||||||
if rotation is None:
|
if rotation is None:
|
||||||
manual_rot_img = ori_color_image
|
manual_rot_img = ori_color_image
|
||||||
assert camera.rotation is None
|
assert camera.rotation is None
|
||||||
@@ -197,10 +198,14 @@ def test_camera(request, camera_type, mock):
|
|||||||
@require_camera
|
@require_camera
|
||||||
def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
|
def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
|
||||||
# TODO(rcadene): refactor
|
# TODO(rcadene): refactor
|
||||||
|
save_images_from_cameras = None
|
||||||
|
|
||||||
if camera_type == "opencv":
|
if camera_type == "opencv":
|
||||||
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
|
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
|
||||||
elif camera_type == "intelrealsense":
|
elif camera_type == "intelrealsense":
|
||||||
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
|
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported camera type: {camera_type}")
|
||||||
|
|
||||||
# Small `record_time_s` to speedup unit tests
|
# Small `record_time_s` to speedup unit tests
|
||||||
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
|
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
|
||||||
|
|||||||
@@ -30,12 +30,12 @@ from lerobot.common.datasets.compute_stats import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
def mock_load_image_as_numpy(_path, dtype, channel_first):
|
||||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="sample_array")
|
||||||
def sample_array():
|
def fixture_sample_array():
|
||||||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ def test_sample_indices():
|
|||||||
|
|
||||||
|
|
||||||
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
||||||
def test_sample_images(mock_load):
|
def test_sample_images(_mock_load):
|
||||||
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
||||||
images = sample_images(image_paths)
|
images = sample_images(image_paths)
|
||||||
assert isinstance(images, np.ndarray)
|
assert isinstance(images, np.ndarray)
|
||||||
|
|||||||
@@ -48,8 +48,8 @@ from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
|||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="image_dataset")
|
||||||
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
def fixture_image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {
|
features = {
|
||||||
"image": {
|
"image": {
|
||||||
"dtype": "image",
|
"dtype": "image",
|
||||||
@@ -374,7 +374,7 @@ def test_factory(env_name, repo_id, policy_name):
|
|||||||
if required:
|
if required:
|
||||||
assert key in item, f"{key}"
|
assert key in item, f"{key}"
|
||||||
else:
|
else:
|
||||||
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
|
logging.warning('Missing key in dataset: "%s" not in %s.', key, dataset)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if delta_timestamps is not None and key in delta_timestamps:
|
if delta_timestamps is not None and key in delta_timestamps:
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
|||||||
table = hf_dataset.data.table
|
table = hf_dataset.data.table
|
||||||
total_episodes = calculate_total_episode(hf_dataset)
|
total_episodes = calculate_total_episode(hf_dataset)
|
||||||
for ep_idx in range(total_episodes):
|
for ep_idx in range(total_episodes):
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
ep_table = table.filter(
|
||||||
|
pc.equal(table["episode_index"], ep_idx)
|
||||||
|
) # TODO(Steven): What is this check supposed to do?
|
||||||
episode_lengths.insert(ep_idx, len(ep_table))
|
episode_lengths.insert(ep_idx, len(ep_table))
|
||||||
|
|
||||||
cumulative_lengths = list(accumulate(episode_lengths))
|
cumulative_lengths = list(accumulate(episode_lengths))
|
||||||
@@ -52,8 +54,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="synced_timestamps_factory", scope="module")
|
||||||
def synced_timestamps_factory(hf_dataset_factory):
|
def fixture_synced_timestamps_factory(hf_dataset_factory):
|
||||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
hf_dataset = hf_dataset_factory(fps=fps)
|
hf_dataset = hf_dataset_factory(fps=fps)
|
||||||
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
||||||
@@ -64,8 +66,8 @@ def synced_timestamps_factory(hf_dataset_factory):
|
|||||||
return _create_synced_timestamps
|
return _create_synced_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="unsynced_timestamps_factory", scope="module")
|
||||||
def unsynced_timestamps_factory(synced_timestamps_factory):
|
def fixture_unsynced_timestamps_factory(synced_timestamps_factory):
|
||||||
def _create_unsynced_timestamps(
|
def _create_unsynced_timestamps(
|
||||||
fps: int = 30, tolerance_s: float = 1e-4
|
fps: int = 30, tolerance_s: float = 1e-4
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
@@ -76,8 +78,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
|
|||||||
return _create_unsynced_timestamps
|
return _create_unsynced_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="slightly_off_timestamps_factory", scope="module")
|
||||||
def slightly_off_timestamps_factory(synced_timestamps_factory):
|
def fixture_slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||||
def _create_slightly_off_timestamps(
|
def _create_slightly_off_timestamps(
|
||||||
fps: int = 30, tolerance_s: float = 1e-4
|
fps: int = 30, tolerance_s: float = 1e-4
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
@@ -88,22 +90,26 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
|
|||||||
return _create_slightly_off_timestamps
|
return _create_slightly_off_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="valid_delta_timestamps_factory", scope="module")
|
||||||
def valid_delta_timestamps_factory():
|
def fixture_valid_delta_timestamps_factory():
|
||||||
def _create_valid_delta_timestamps(
|
def _create_valid_delta_timestamps(
|
||||||
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
fps: int = 30, keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if keys is None:
|
||||||
|
keys = DUMMY_MOTOR_FEATURES
|
||||||
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
||||||
return delta_timestamps
|
return delta_timestamps
|
||||||
|
|
||||||
return _create_valid_delta_timestamps
|
return _create_valid_delta_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="invalid_delta_timestamps_factory", scope="module")
|
||||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
def fixture_invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||||
def _create_invalid_delta_timestamps(
|
def _create_invalid_delta_timestamps(
|
||||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if keys is None:
|
||||||
|
keys = DUMMY_MOTOR_FEATURES
|
||||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||||
# Modify a single timestamp just outside tolerance
|
# Modify a single timestamp just outside tolerance
|
||||||
for key in keys:
|
for key in keys:
|
||||||
@@ -113,11 +119,13 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
|||||||
return _create_invalid_delta_timestamps
|
return _create_invalid_delta_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="slightly_off_delta_timestamps_factory", scope="module")
|
||||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
def fixture_slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||||
def _create_slightly_off_delta_timestamps(
|
def _create_slightly_off_delta_timestamps(
|
||||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if keys is None:
|
||||||
|
keys = DUMMY_MOTOR_FEATURES
|
||||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||||
# Modify a single timestamp just inside tolerance
|
# Modify a single timestamp just inside tolerance
|
||||||
for key in delta_timestamps:
|
for key in delta_timestamps:
|
||||||
@@ -128,9 +136,11 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
|||||||
return _create_slightly_off_delta_timestamps
|
return _create_slightly_off_delta_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="delta_indices_factory", scope="module")
|
||||||
def delta_indices_factory():
|
def fixture_delta_indices_factory():
|
||||||
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
def _delta_indices(keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||||
|
if keys is None:
|
||||||
|
keys = DUMMY_MOTOR_FEATURES
|
||||||
return {key: list(range(*min_max_range)) for key in keys}
|
return {key: list(range(*min_max_range)) for key in keys}
|
||||||
|
|
||||||
return _delta_indices
|
return _delta_indices
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def _run_script(path):
|
|||||||
|
|
||||||
|
|
||||||
def _read_file(path):
|
def _read_file(path):
|
||||||
with open(path) as file:
|
with open(path, encoding="utf-8") as file:
|
||||||
return file.read()
|
return file.read()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -37,8 +37,8 @@ from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
|||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="color_jitters")
|
||||||
def color_jitters():
|
def fixture_color_jitters():
|
||||||
return [
|
return [
|
||||||
v2.ColorJitter(brightness=0.5),
|
v2.ColorJitter(brightness=0.5),
|
||||||
v2.ColorJitter(contrast=0.5),
|
v2.ColorJitter(contrast=0.5),
|
||||||
@@ -46,18 +46,18 @@ def color_jitters():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="single_transforms")
|
||||||
def single_transforms():
|
def fixture_single_transforms():
|
||||||
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="img_tensor")
|
||||||
def img_tensor(single_transforms):
|
def fixture_img_tensor(single_transforms):
|
||||||
return single_transforms["original_frame"]
|
return single_transforms["original_frame"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="default_transforms")
|
||||||
def default_transforms():
|
def fixture_default_transforms():
|
||||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import pytest
|
|||||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="tmp_json_file")
|
||||||
def tmp_json_file(tmp_path: Path):
|
def fixture_tmp_json_file(tmp_path: Path):
|
||||||
"""Writes `data` to a temporary JSON file and returns the file's path."""
|
"""Writes `data` to a temporary JSON file and returns the file's path."""
|
||||||
|
|
||||||
def _write(data: Any) -> Path:
|
def _write(data: Any) -> Path:
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import pytest
|
|||||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="mock_metrics")
|
||||||
def mock_metrics():
|
def fixture_mock_metrics():
|
||||||
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
||||||
|
|
||||||
|
|
||||||
@@ -87,6 +87,7 @@ def test_metrics_tracker_getattr(mock_metrics):
|
|||||||
_ = tracker.non_existent_metric
|
_ = tracker.non_existent_metric
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): I don't understand what's supposed to happen here
|
||||||
def test_metrics_tracker_setattr(mock_metrics):
|
def test_metrics_tracker_setattr(mock_metrics):
|
||||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||||
tracker.loss = 2.0
|
tracker.loss = 2.0
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ def test_non_mutate():
|
|||||||
def test_index_error_no_data():
|
def test_index_error_no_data():
|
||||||
buffer, _ = make_new_buffer()
|
buffer, _ = make_new_buffer()
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(IndexError):
|
||||||
buffer[0]
|
_ = buffer[0]
|
||||||
|
|
||||||
|
|
||||||
def test_index_error_with_data():
|
def test_index_error_with_data():
|
||||||
@@ -83,9 +83,9 @@ def test_index_error_with_data():
|
|||||||
new_data = make_spoof_data_frames(1, n_frames)
|
new_data = make_spoof_data_frames(1, n_frames)
|
||||||
buffer.add_data(new_data)
|
buffer.add_data(new_data)
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(IndexError):
|
||||||
buffer[n_frames]
|
_ = buffer[n_frames]
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(IndexError):
|
||||||
buffer[-n_frames - 1]
|
_ = buffer[-n_frames - 1]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("do_reload", [False, True])
|
@pytest.mark.parametrize("do_reload", [False, True])
|
||||||
@@ -185,7 +185,7 @@ def test_delta_timestamps_outside_tolerance_inside_episode_range():
|
|||||||
buffer.add_data(new_data)
|
buffer.add_data(new_data)
|
||||||
buffer.tolerance_s = 0.04
|
buffer.tolerance_s = 0.04
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
buffer[2]
|
_ = buffer[2]
|
||||||
|
|
||||||
|
|
||||||
def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
||||||
@@ -229,6 +229,7 @@ def test_compute_sampler_weights_trivial(
|
|||||||
weights = compute_sampler_weights(
|
weights = compute_sampler_weights(
|
||||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||||
)
|
)
|
||||||
|
expected_weights: torch.Tensor = None
|
||||||
if offline_dataset_size == 0 or online_dataset_size == 0:
|
if offline_dataset_size == 0 or online_dataset_size == 0:
|
||||||
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
||||||
elif online_sampling_ratio == 0:
|
elif online_sampling_ratio == 0:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
# pylint: disable=redefined-outer-name, unused-argument
|
||||||
import inspect
|
import inspect
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
# pylint: disable=redefined-outer-name, unused-argument
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -32,16 +32,16 @@ from lerobot.common.utils.import_utils import is_package_available
|
|||||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
TEST_ROBOT_TYPES = []
|
TEST_ROBOT_TYPES = []
|
||||||
for robot_type in available_robots:
|
for available_robot_type in available_robots:
|
||||||
TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
|
TEST_ROBOT_TYPES += [(available_robot_type, True), (available_robot_type, False)]
|
||||||
|
|
||||||
TEST_CAMERA_TYPES = []
|
TEST_CAMERA_TYPES = []
|
||||||
for camera_type in available_cameras:
|
for available_camera_type in available_cameras:
|
||||||
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
|
TEST_CAMERA_TYPES += [(available_camera_type, True), (available_camera_type, False)]
|
||||||
|
|
||||||
TEST_MOTOR_TYPES = []
|
TEST_MOTOR_TYPES = []
|
||||||
for motor_type in available_motors:
|
for available_motor_type in available_motors:
|
||||||
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
|
TEST_MOTOR_TYPES += [(available_motor_type, True), (available_motor_type, False)]
|
||||||
|
|
||||||
# Camera indices used for connecting physical cameras
|
# Camera indices used for connecting physical cameras
|
||||||
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
||||||
@@ -72,7 +72,6 @@ def require_x86_64_kernel(func):
|
|||||||
"""
|
"""
|
||||||
Decorator that skips the test if plateform device is not an x86_64 cpu.
|
Decorator that skips the test if plateform device is not an x86_64 cpu.
|
||||||
"""
|
"""
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
@@ -87,7 +86,6 @@ def require_cpu(func):
|
|||||||
"""
|
"""
|
||||||
Decorator that skips the test if device is not cpu.
|
Decorator that skips the test if device is not cpu.
|
||||||
"""
|
"""
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
@@ -102,7 +100,6 @@ def require_cuda(func):
|
|||||||
"""
|
"""
|
||||||
Decorator that skips the test if cuda is not available.
|
Decorator that skips the test if cuda is not available.
|
||||||
"""
|
"""
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
@@ -288,17 +285,17 @@ def mock_calibration_dir(calibration_dir):
|
|||||||
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||||
}
|
}
|
||||||
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
|
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
|
||||||
with open(calibration_dir / "main_follower.json", "w") as f:
|
with open(calibration_dir / "main_follower.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "main_leader.json", "w") as f:
|
with open(calibration_dir / "main_leader.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "left_follower.json", "w") as f:
|
with open(calibration_dir / "left_follower.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "left_leader.json", "w") as f:
|
with open(calibration_dir / "left_leader.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "right_follower.json", "w") as f:
|
with open(calibration_dir / "right_follower.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "right_leader.json", "w") as f:
|
with open(calibration_dir / "right_leader.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user