Compare commits
15 Commits
torchcodec...fix/lint_w
| Author | SHA1 | Date |
|---|---|---|
|  | e511e7eda5 |  |
|  | f5ed3723f0 |  |
|  | b104be0d04 |  |
|  | f9e4a1f5c4 |  |
|  | 0eb56cec14 |  |
|  | e59ef036e1 |  |
|  | 9b380eaf67 |  |
|  | 1187604ba0 |  |
|  | 5c6f2d2cd0 |  |
|  | 652fedf69c |  |
|  | 85214ec303 |  |
|  | dffa5a18db |  |
|  | 301f152a34 |  |
|  | 0ed08c0b1f |  |
|  | 254bc707e7 |  |
```diff
@@ -67,6 +67,7 @@ def parse_int_or_none(value) -> int | None:
 def check_datasets_formats(repo_ids: list) -> None:
     for repo_id in repo_ids:
         dataset = LeRobotDataset(repo_id)
+        # TODO(Steven): Seems this API has changed
         if dataset.video:
             raise ValueError(
                 f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
@@ -222,7 +222,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
 if __name__ == "__main__":
     # To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
-    repo_id = "lerobot/pusht"
+    repository_id = "lerobot/pusht"
 
     modes = ["video", "image", "keypoints"]
     # Uncomment if you want to try with a specific mode
@@ -230,13 +230,13 @@ if __name__ == "__main__":
     # modes = ["image"]
     # modes = ["keypoints"]
 
-    raw_dir = Path("data/lerobot-raw/pusht_raw")
-    for mode in modes:
-        if mode in ["image", "keypoints"]:
-            repo_id += f"_{mode}"
+    data_dir = Path("data/lerobot-raw/pusht_raw")
+    for available_mode in modes:
+        if available_mode in ["image", "keypoints"]:
+            repository_id += f"_{available_mode}"
 
         # download and load raw dataset, create LeRobotDataset, populate it, push to hub
-        main(raw_dir, repo_id=repo_id, mode=mode)
+        main(data_dir, repo_id=repository_id, mode=available_mode)
 
     # Uncomment if you want to load the local dataset and explore it
     # dataset = LeRobotDataset(repo_id=repo_id)
```
```diff
@@ -13,8 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-from pprint import pformat
 
 import torch
@@ -98,17 +96,17 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
         )
     else:
         raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
-        dataset = MultiLeRobotDataset(
-            cfg.dataset.repo_id,
-            # TODO(aliberts): add proper support for multi dataset
-            # delta_timestamps=delta_timestamps,
-            image_transforms=image_transforms,
-            video_backend=cfg.dataset.video_backend,
-        )
-        logging.info(
-            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-            f"{pformat(dataset.repo_id_to_index, indent=2)}"
-        )
+        # dataset = MultiLeRobotDataset(
+        #     cfg.dataset.repo_id,
+        #     # TODO(aliberts): add proper support for multi dataset
+        #     # delta_timestamps=delta_timestamps,
+        #     image_transforms=image_transforms,
+        #     video_backend=cfg.dataset.video_backend,
+        # )
+        # logging.info(
+        #     "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
+        #     f"{pformat(dataset.repo_id_to_index, indent=2)}"
+        # )
 
     if cfg.dataset.use_imagenet_stats:
         for key in dataset.meta.camera_keys:
```
```diff
@@ -81,21 +81,21 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
         print(f"Error writing image {fpath}: {e}")
 
 
-def worker_thread_loop(queue: queue.Queue):
+def worker_thread_loop(task_queue: queue.Queue):
     while True:
-        item = queue.get()
+        item = task_queue.get()
         if item is None:
-            queue.task_done()
+            task_queue.task_done()
             break
         image_array, fpath = item
         write_image(image_array, fpath)
-        queue.task_done()
+        task_queue.task_done()
 
 
-def worker_process(queue: queue.Queue, num_threads: int):
+def worker_process(task_queue: queue.Queue, num_threads: int):
     threads = []
     for _ in range(num_threads):
-        t = threading.Thread(target=worker_thread_loop, args=(queue,))
+        t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
         t.daemon = True
         t.start()
         threads.append(t)
```
```diff
@@ -87,6 +87,7 @@ class LeRobotDatasetMetadata:
         self.repo_id = repo_id
         self.revision = revision if revision else CODEBASE_VERSION
         self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
+        self.stats = None
 
         try:
             if force_cache_sync:
@@ -102,10 +103,10 @@ class LeRobotDatasetMetadata:
 
     def load_metadata(self):
         self.info = load_info(self.root)
-        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
+        check_version_compatibility(self.repo_id, self.version, CODEBASE_VERSION)
         self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
-        if self._version < packaging.version.parse("v2.1"):
+        if self.version < packaging.version.parse("v2.1"):
             self.stats = load_stats(self.root)
             self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
         else:
@@ -127,7 +128,7 @@ class LeRobotDatasetMetadata:
         )
 
     @property
-    def _version(self) -> packaging.version.Version:
+    def version(self) -> packaging.version.Version:
         """Codebase version used to create this dataset."""
         return packaging.version.parse(self.info["codebase_version"])
 
```
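The `_version` helper is promoted to a public `version` property because it is read outside the class (see the dataset hunks below). A minimal sketch of the pattern, with an illustrative class name that is not from the repository:

```python
import packaging.version


class DatasetMeta:
    def __init__(self, info: dict) -> None:
        self.info = info

    @property
    def version(self) -> packaging.version.Version:
        """Codebase version used to create this dataset."""
        # No leading underscore: the attribute is part of the public API and is
        # compared against a minimum supported version by other modules.
        return packaging.version.parse(self.info["codebase_version"])


meta = DatasetMeta({"codebase_version": "v2.1"})
print(meta.version >= packaging.version.parse("v2.1"))  # True
```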
```diff
@@ -321,8 +322,9 @@ class LeRobotDatasetMetadata:
             robot_type = robot.robot_type
             if not all(cam.fps == fps for cam in robot.cameras.values()):
                 logging.warning(
-                    f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
-                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
+                    "Some cameras in your %s robot don't have an fps matching the fps of your dataset."
+                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks.",
+                    robot.robot_type,
                 )
         elif features is None:
             raise ValueError(
@@ -486,7 +488,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.meta = LeRobotDatasetMetadata(
             self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
         )
-        if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
+        if self.episodes is not None and self.meta.version >= packaging.version.parse("v2.1"):
             episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
             self.stats = aggregate_stats(episodes_stats)
 
@@ -518,7 +520,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self,
         branch: str | None = None,
         tags: list | None = None,
-        license: str | None = "apache-2.0",
+        dataset_license: str | None = "apache-2.0",
         tag_version: bool = True,
         push_videos: bool = True,
         private: bool = False,
@@ -561,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
             card = create_lerobot_dataset_card(
-                tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
+                tags=tags, dataset_info=self.meta.info, license=dataset_license, **card_kwargs
             )
             card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
 
@@ -842,6 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
             None.
         """
+        episode_buffer = None
         if not episode_data:
             episode_buffer = self.episode_buffer
 
@@ -1086,8 +1089,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
             extra_keys = set(ds.features).difference(intersection_features)
             logging.warning(
-                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
-                "other datasets."
+                "keys %s of %s were disabled as they are not contained in all the other datasets.",
+                extra_keys,
+                repo_id,
             )
             self.disabled_features.update(extra_keys)
 
```
```diff
@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
     # rechunk recompress
     group.move(name, tmp_key)
     old_arr = group[tmp_key]
-    n_copied, n_skipped, n_bytes_copied = zarr.copy(
+    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
         source=old_arr,
         dest=group,
         name=name,
@@ -192,7 +192,7 @@ class ReplayBuffer:
         else:
             root = zarr.group(store=store)
             # copy without recompression
-            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
+            _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
                 source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
             )
             data_group = root.create_group("data", overwrite=True)
@@ -205,7 +205,7 @@ class ReplayBuffer:
                 if cks == value.chunks and cpr == value.compressor:
                     # copy without recompression
                     this_path = "/data/" + key
-                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
+                    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
                         source=src_store,
                         dest=store,
                         source_path=this_path,
@@ -214,7 +214,7 @@ class ReplayBuffer:
                     )
                 else:
                     # copy with recompression
-                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
+                    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
                         source=value,
                         dest=data_group,
                         name=key,
@@ -275,7 +275,7 @@ class ReplayBuffer:
         compressors = {}
         if self.backend == "zarr":
             # recompression free copy
-            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
+            _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
                 source=self.root.store,
                 dest=store,
                 source_path="/meta",
@@ -297,7 +297,7 @@ class ReplayBuffer:
                 if cks == value.chunks and cpr == value.compressor:
                     # copy without recompression
                     this_path = "/data/" + key
-                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
+                    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
                         source=self.root.store,
                         dest=store,
                         source_path=this_path,
```
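These hunks apply the usual lint convention for intentionally unused results: a leading underscore marks the names as deliberately ignored while keeping the tuple unpacking explicit. A tiny, self-contained sketch (the helper below is hypothetical, standing in for `zarr.copy_store`):

```python
def copy_report() -> tuple[int, int, int]:
    # Stand-in for a call such as zarr.copy_store(), which returns
    # (n_copied, n_skipped, n_bytes_copied).
    return 12, 0, 4096


# Underscore-prefixed names tell linters (and readers) the values are intentionally unused.
_n_copied, _n_skipped, _n_bytes_copied = copy_report()
```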
```diff
@@ -162,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
     )
     raw_dir.mkdir(parents=True, exist_ok=True)
 
-    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
+    logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
     snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
-    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
+    logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
 
 
 def download_all_raw_datasets(data_dir: Path | None = None):
```
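This is the lazy-logging pattern applied throughout the branch: the message is passed as a %-style template plus arguments, so interpolation only happens if the record is actually emitted. A minimal sketch of the difference:

```python
import logging

logging.basicConfig(level=logging.WARNING)
user_id, dataset_id = "lerobot", "pusht"

# Eager: the f-string is built even though INFO records are filtered out here.
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")

# Lazy: the template and arguments are only combined when the record is emitted.
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
```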
```diff
@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
                 assert data[f"/observations/images/{camera}"].ndim == 2
             else:
                 assert data[f"/observations/images/{camera}"].ndim == 4
-                b, h, w, c = data[f"/observations/images/{camera}"].shape
+                _, h, w, c = data[f"/observations/images/{camera}"].shape
                 assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
 
 
@@ -103,6 +103,7 @@ def load_from_raw(
 
             state = torch.from_numpy(ep["/observations/qpos"][:])
             action = torch.from_numpy(ep["/action"][:])
+            velocity = None
             if "/observations/qvel" in ep:
                 velocity = torch.from_numpy(ep["/observations/qvel"][:])
             if "/observations/effort" in ep:
 
@@ -96,6 +96,7 @@ def from_raw_to_lerobot_format(
     if fps is None:
         fps = 30
 
+    # TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
     data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
 
@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
     return True
 
 
-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
+):
     # Load data stream that will be used as reference for the timestamps synchronization
     reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
     if len(reference_files) == 0:
 
@@ -55,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
 
     num_images = len(imgs_array)
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
+        _ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
 
 
 def get_default_encoding() -> dict:
```
```diff
@@ -92,24 +92,23 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
     episode_data_index = {"from": [], "to": []}
 
     current_episode = None
-    """
-    The episode_index is a list of integers, each representing the episode index of the corresponding example.
-    For instance, the following is a valid episode_index:
-      [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
-
-    Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
-    ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
-        {
-            "from": [0, 3, 7],
-            "to": [3, 7, 12]
-        }
-    """
+    # The episode_index is a list of integers, each representing the episode index of the corresponding example.
+    # For instance, the following is a valid episode_index:
+    #   [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
+    #
+    # Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
+    # ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
+    #     {
+    #         "from": [0, 3, 7],
+    #         "to": [3, 7, 12]
+    #     }
     if len(hf_dataset) == 0:
         episode_data_index = {
             "from": torch.tensor([]),
             "to": torch.tensor([]),
         }
         return episode_data_index
+    idx = None
     for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
         if episode_idx != current_episode:
             # We encountered a new episode, so we append its starting location to the "from" list
```
```diff
@@ -23,6 +23,7 @@ from torchvision.transforms.v2 import Transform
 from torchvision.transforms.v2 import functional as F  # noqa: N812
 
 
+# TODO(Steven): Missing transform() implementation
 class RandomSubsetApply(Transform):
     """Apply a random subset of N transformations from a list of transformations.
 
@@ -218,6 +219,7 @@ def make_transform_from_config(cfg: ImageTransformConfig):
         raise ValueError(f"Transform '{cfg.type}' is not valid.")
 
 
+# TODO(Steven): Missing transform() implementation
 class ImageTransforms(Transform):
     """A class to compose image transforms based on configuration."""
 
```
```diff
@@ -135,21 +135,21 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
 
 def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
     # Embed image bytes into the table before saving to parquet
-    format = dataset.format
+    ds_format = dataset.format
     dataset = dataset.with_format("arrow")
     dataset = dataset.map(embed_table_storage, batched=False)
-    dataset = dataset.with_format(**format)
+    dataset = dataset.with_format(**ds_format)
     return dataset
 
 
 def load_json(fpath: Path) -> Any:
-    with open(fpath) as f:
+    with open(fpath, encoding="utf-8") as f:
         return json.load(f)
 
 
 def write_json(data: dict, fpath: Path) -> None:
     fpath.parent.mkdir(exist_ok=True, parents=True)
-    with open(fpath, "w") as f:
+    with open(fpath, "w", encoding="utf-8") as f:
         json.dump(data, f, indent=4, ensure_ascii=False)
 
```
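The `open()` calls gain an explicit `encoding="utf-8"` because, without it, Python falls back to the platform's locale encoding, which is what the lint rule flags. A small sketch with a hypothetical file path:

```python
import json
from pathlib import Path

fpath = Path("stats.json")  # illustrative path, not from the repository
fpath.write_text(json.dumps({"note": "café"}, ensure_ascii=False), encoding="utf-8")

# Without `encoding`, this read can mis-decode (or fail) on a non-UTF-8 locale, e.g. cp1252 on Windows.
with open(fpath, encoding="utf-8") as f:
    stats = json.load(f)
print(stats)
```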
```diff
@@ -300,7 +300,7 @@ def check_version_compatibility(
     if v_check.major < v_current.major and enforce_breaking_major:
         raise BackwardCompatibilityError(repo_id, v_check)
     elif v_check.minor < v_current.minor:
-        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
+        logging.warning("%s", V21_MESSAGE.format(repo_id=repo_id, version=v_check))
 
 
 def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
@@ -348,7 +348,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
     if compatibles:
         return_version = max(compatibles)
         if return_version < target_version:
-            logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
+            logging.warning(
+                "Revision %s for %s not found, using version v%s", version, repo_id, return_version
+            )
         return f"v{return_version}"
 
     lower_major = [v for v in hub_versions if v.major < target_version.major]
@@ -403,7 +405,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
     for key, ft in features.items():
         shape = ft["shape"]
         if ft["dtype"] in ["image", "video"]:
-            type = FeatureType.VISUAL
+            feature_type = FeatureType.VISUAL
             if len(shape) != 3:
                 raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
 
@@ -412,16 +414,16 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
             if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                 shape = (shape[2], shape[0], shape[1])
         elif key == "observation.environment_state":
-            type = FeatureType.ENV
+            feature_type = FeatureType.ENV
         elif key.startswith("observation"):
-            type = FeatureType.STATE
+            feature_type = FeatureType.STATE
         elif key == "action":
-            type = FeatureType.ACTION
+            feature_type = FeatureType.ACTION
         else:
             continue
 
         policy_features[key] = PolicyFeature(
-            type=type,
+            type=feature_type,
            shape=shape,
         )
 
```
```diff
@@ -871,11 +871,11 @@ def batch_convert():
         try:
             convert_dataset(repo_id, LOCAL_DIR, **kwargs)
             status = f"{repo_id}: success."
-            with open(logfile, "a") as file:
+            with open(logfile, "a", encoding="utf-8") as file:
                 file.write(status + "\n")
         except Exception:
             status = f"{repo_id}: failed\n {traceback.format_exc()}"
-            with open(logfile, "a") as file:
+            with open(logfile, "a", encoding="utf-8") as file:
                 file.write(status + "\n")
             continue
 
@@ -190,11 +190,11 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
 
     json_path = v2_dir / STATS_PATH
     json_path.parent.mkdir(exist_ok=True, parents=True)
-    with open(json_path, "w") as f:
+    with open(json_path, "w", encoding="utf-8") as f:
         json.dump(serialized_stats, f, indent=4)
 
     # Sanity check
-    with open(json_path) as f:
+    with open(json_path, encoding="utf-8") as f:
         stats_json = json.load(f)
 
     stats_json = flatten_dict(stats_json)
@@ -213,7 +213,7 @@ def get_features_from_hf_dataset(
         dtype = ft.dtype
         shape = (1,)
         names = None
-        if isinstance(ft, datasets.Sequence):
+        elif isinstance(ft, datasets.Sequence):
             assert isinstance(ft.feature, datasets.Value)
             dtype = ft.feature.dtype
             shape = (ft.length,)
@@ -232,6 +232,8 @@ def get_features_from_hf_dataset(
             dtype = "video"
             shape = None  # Add shape later
             names = ["height", "width", "channels"]
+        else:
+            raise NotImplementedError(f"Feature type {ft._type} not supported.")
 
         features[key] = {
             "dtype": dtype,
@@ -358,9 +360,9 @@ def move_videos(
             if len(video_dirs) == 1:
                 video_path = video_dirs[0] / video_file
             else:
-                for dir in video_dirs:
-                    if (dir / video_file).is_file():
-                        video_path = dir / video_file
+                for v_dir in video_dirs:
+                    if (v_dir / video_file).is_file():
+                        video_path = v_dir / video_file
                         break
 
             video_path.rename(work_dir / target_path)
@@ -652,6 +654,7 @@ def main():
     if not args.local_dir:
         args.local_dir = Path("/tmp/lerobot_dataset_v2")
 
+    robot_config = None
     if args.robot is not None:
         robot_config = make_robot_config(args.robot)
 
@@ -50,7 +50,7 @@ def fix_dataset(repo_id: str) -> str:
         return f"{repo_id}: skipped (no diff)"
 
     if diff_meta_parquet:
-        logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
+        logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
         assert diff_meta_parquet == {"language_instruction"}
         lerobot_metadata.features.pop("language_instruction")
         write_info(lerobot_metadata.info, lerobot_metadata.root)
@@ -79,7 +79,7 @@ def batch_fix():
         status = f"{repo_id}: failed\n {traceback.format_exc()}"
 
     logging.info(status)
-    with open(logfile, "a") as file:
+    with open(logfile, "a", encoding="utf-8") as file:
         file.write(status + "\n")
 
@@ -46,7 +46,7 @@ def batch_convert():
     except Exception:
         status = f"{repo_id}: failed\n {traceback.format_exc()}"
 
-        with open(logfile, "a") as file:
+        with open(logfile, "a", encoding="utf-8") as file:
            file.write(status + "\n")
 
```
```diff
@@ -45,6 +45,9 @@ V21 = "v2.1"
 
 
 class SuppressWarnings:
+    def __init__(self):
+        self.previous_level = None
+
     def __enter__(self):
         self.previous_level = logging.getLogger().getEffectiveLevel()
         logging.getLogger().setLevel(logging.ERROR)
```
```diff
@@ -83,7 +83,7 @@ def decode_video_frames_torchvision(
     for frame in reader:
         current_ts = frame["pts"]
         if log_loaded_timestamps:
-            logging.info(f"frame loaded at timestamp={current_ts:.4f}")
+            logging.info("frame loaded at timestamp=%.4f", current_ts)
         loaded_frames.append(frame["data"])
         loaded_ts.append(current_ts)
         if current_ts >= last_ts:
@@ -118,7 +118,7 @@ def decode_video_frames_torchvision(
     closest_ts = loaded_ts[argmin_]
 
     if log_loaded_timestamps:
-        logging.info(f"{closest_ts=}")
+        logging.info("closest_ts=%s", closest_ts)
 
     # convert to the pytorch format which is float32 in [0,1] range (and channel first)
     closest_frames = closest_frames.type(torch.float32) / 255
@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
         "json",
         str(video_path),
     ]
-    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    result = subprocess.run(
+        ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
+    )
     if result.returncode != 0:
         raise RuntimeError(f"Error running ffprobe: {result.stderr}")
 
@@ -263,7 +265,9 @@ def get_video_info(video_path: Path | str) -> dict:
         "json",
         str(video_path),
     ]
-    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    result = subprocess.run(
+        ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
+    )
     if result.returncode != 0:
         raise RuntimeError(f"Error running ffprobe: {result.stderr}")
 
```
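With `check=True`, `subprocess.run` raises `CalledProcessError` on a non-zero exit status, so the retained `if result.returncode != 0` branch can no longer fire; a caller that still wants the custom `RuntimeError` would catch the exception instead. A hedged sketch of that alternative (illustrative command, not the code in this branch):

```python
import subprocess

cmd = ["ffprobe", "-v", "error", "-show_format", "-of", "json", "video.mp4"]  # illustrative

try:
    result = subprocess.run(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
    )
except subprocess.CalledProcessError as e:
    # check=True turns a non-zero exit status into an exception carrying stdout/stderr.
    raise RuntimeError(f"Error running ffprobe: {e.stderr}") from e
```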
```diff
@@ -32,7 +32,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
     def type(self) -> str:
         return self.get_choice_name(self.__class__)
 
-    @abc.abstractproperty
+    @property
+    @abc.abstractmethod
     def gym_kwargs(self) -> dict:
         raise NotImplementedError()
 
```
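`abc.abstractproperty` has been deprecated since Python 3.3; stacking `@property` over `@abc.abstractmethod` is the documented replacement. A minimal, standalone sketch (the concrete subclass below is hypothetical):

```python
import abc


class EnvConfig(abc.ABC):
    @property
    @abc.abstractmethod
    def gym_kwargs(self) -> dict:
        """Subclasses must expose this as a read-only property."""
        raise NotImplementedError()


class PushtEnv(EnvConfig):  # hypothetical concrete config, for illustration only
    @property
    def gym_kwargs(self) -> dict:
        return {"obs_type": "pixels", "render_mode": "rgb_array"}


print(PushtEnv().gym_kwargs)
```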
```diff
@@ -44,7 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
         return "adam"
 
     @abc.abstractmethod
-    def build(self) -> torch.optim.Optimizer:
+    def build(self, params: dict) -> torch.optim.Optimizer:
         raise NotImplementedError
 
```
```diff
@@ -140,7 +140,7 @@ class ACTConfig(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
 
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
```
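The same substitution recurs in the Diffusion, PI0, TDMPC and VQ-BeT configs below: a string literal that is not the first statement of a body is not a docstring, just an expression that is evaluated and discarded (pylint's `pointless-string-statement`). A toy illustration:

```python
def validate(x: int) -> None:
    """Real docstring: only the first statement of a body is treated as documentation."""
    "Input validation (not exhaustive)."  # pointless string expression, evaluated and thrown away
    # A comment carries the same information without creating a throwaway object.
    if x <= 0:
        raise ValueError(f"x must be positive, got {x}")


validate(3)
```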
```diff
@@ -222,6 +222,8 @@ class ACTTemporalEnsembler:
         self.chunk_size = chunk_size
         self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
         self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
+        self.ensembled_actions = None
+        self.ensembled_actions_count = None
         self.reset()
 
     def reset(self):
@@ -162,7 +162,7 @@ class DiffusionConfig(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
 
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
@@ -170,6 +170,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
         raise ValueError(f"Unsupported noise scheduler type {name}")
 
 
+# TODO(Steven): Missing forward() implementation
 class DiffusionModel(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
@@ -203,6 +204,7 @@ class DiffusionModel(nn.Module):
         )
 
         if config.num_inference_steps is None:
+            # TODO(Steven): Consider type check?
             self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
         else:
             self.num_inference_steps = config.num_inference_steps
@@ -333,7 +335,7 @@ class DiffusionModel(nn.Module):
         # Sample a random noising timestep for each item in the batch.
         timesteps = torch.randint(
             low=0,
-            high=self.noise_scheduler.config.num_train_timesteps,
+            high=self.noise_scheduler.config.num_train_timesteps,  # TODO(Steven): Consider type check?
             size=(trajectory.shape[0],),
             device=trajectory.device,
         ).long()
 
```
```diff
@@ -69,12 +69,12 @@ def create_stats_buffers(
                 }
             )
         elif norm_mode is NormalizationMode.MIN_MAX:
-            min = torch.ones(shape, dtype=torch.float32) * torch.inf
-            max = torch.ones(shape, dtype=torch.float32) * torch.inf
+            min_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
+            max_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
             buffer = nn.ParameterDict(
                 {
-                    "min": nn.Parameter(min, requires_grad=False),
-                    "max": nn.Parameter(max, requires_grad=False),
+                    "min": nn.Parameter(min_norm, requires_grad=False),
+                    "max": nn.Parameter(max_norm, requires_grad=False),
                 }
             )
 
@@ -170,12 +170,12 @@ class Normalize(nn.Module):
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
                 batch[key] = (batch[key] - mean) / (std + 1e-8)
             elif norm_mode is NormalizationMode.MIN_MAX:
-                min = buffer["min"]
-                max = buffer["max"]
-                assert not torch.isinf(min).any(), _no_stats_error_str("min")
-                assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                min_norm = buffer["min"]
+                max_norm = buffer["max"]
+                assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
+                assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
                 # normalize to [0,1]
-                batch[key] = (batch[key] - min) / (max - min + 1e-8)
+                batch[key] = (batch[key] - min_norm) / (max_norm - min_norm + 1e-8)
                 # normalize to [-1, 1]
                 batch[key] = batch[key] * 2 - 1
             else:
@@ -243,12 +243,12 @@ class Unnormalize(nn.Module):
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
                 batch[key] = batch[key] * std + mean
             elif norm_mode is NormalizationMode.MIN_MAX:
-                min = buffer["min"]
-                max = buffer["max"]
-                assert not torch.isinf(min).any(), _no_stats_error_str("min")
-                assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                min_norm = buffer["min"]
+                max_norm = buffer["max"]
+                assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
+                assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
                 batch[key] = (batch[key] + 1) / 2
-                batch[key] = batch[key] * (max - min) + min
+                batch[key] = batch[key] * (max_norm - min_norm) + min_norm
             else:
                 raise ValueError(norm_mode)
         return batch
```
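Besides un-shadowing the `min`/`max` builtins, these hunks contain the whole MIN_MAX round trip: scale to [0, 1], shift to [-1, 1], then invert. A compact standalone sketch of the same arithmetic:

```python
import torch


def minmax_normalize(x: torch.Tensor, min_norm: torch.Tensor, max_norm: torch.Tensor) -> torch.Tensor:
    x = (x - min_norm) / (max_norm - min_norm + 1e-8)  # -> [0, 1]
    return x * 2 - 1                                    # -> [-1, 1]


def minmax_unnormalize(x: torch.Tensor, min_norm: torch.Tensor, max_norm: torch.Tensor) -> torch.Tensor:
    x = (x + 1) / 2                                     # back to [0, 1]
    return x * (max_norm - min_norm) + min_norm         # back to the original range


vals = torch.tensor([0.0, 5.0, 10.0])
lo, hi = torch.tensor(0.0), torch.tensor(10.0)
print(minmax_unnormalize(minmax_normalize(vals, lo, hi), lo, hi))  # ~[0., 5., 10.]
```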
```diff
@@ -91,7 +91,7 @@ class PI0Config(PreTrainedConfig):
         super().__post_init__()
 
         # TODO(Steven): Validate device and amp? in all policy configs?
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if self.n_action_steps > self.chunk_size:
             raise ValueError(
                 f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
@@ -55,7 +55,7 @@ def main():
     with open(save_dir / "noise.pkl", "rb") as f:
         noise = pickle.load(f)
 
-    with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
+    with open(ckpt_jax_dir / "assets/norm_stats.json", encoding="utf-8") as f:
         norm_stats = json.load(f)
 
     # Override stats
 
@@ -318,7 +318,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
     return {f"{prefix}{key}": value for key, value in d.items()}
 
 
-def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
+def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, _tokenizer_id: str, output_path: str):
     # Break down orbax ckpts - they are in OCDBT
     initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
     # process projection params
@@ -432,6 +432,6 @@ if __name__ == "__main__":
     convert_pi0_checkpoint(
         checkpoint_dir=args.checkpoint_dir,
         precision=args.precision,
-        tokenizer_id=args.tokenizer_hub_id,
+        _tokenizer_id=args.tokenizer_hub_id,
         output_path=args.output_path,
     )
 
@@ -16,6 +16,7 @@ import torch
 import torch.nn.functional as F  # noqa: N812
 from packaging.version import Version
 
+# TODO(Steven): Consider settings this a dependency constraint
 if Version(torch.__version__) > Version("2.5.0"):
     # Ffex attention is only available from torch 2.5 onwards
     from torch.nn.attention.flex_attention import (
@@ -121,7 +122,7 @@ def flex_attention_forward(
     )
 
     # mask is applied inside the kernel, ideally more efficiently than score_mod.
-    attn_output, attention_weights = flex_attention(
+    attn_output, _attention_weights = flex_attention(
         query_states,
         key_states,
         value_states,
 
```
```diff
@@ -162,7 +162,7 @@ class TDMPCConfig(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
 
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if self.n_gaussian_samples <= 0:
             raise ValueError(
                 f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
@@ -88,6 +88,9 @@ class TDMPCPolicy(PreTrainedPolicy):
         for param in self.model_target.parameters():
             param.requires_grad = False
 
+        self._queues = None
+        self._prev_mean: torch.Tensor | None = None
+
         self.reset()
 
     def get_optim_params(self) -> dict:
@@ -108,7 +111,7 @@ class TDMPCPolicy(PreTrainedPolicy):
         self._queues["observation.environment_state"] = deque(maxlen=1)
         # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
         # CEM for the next step.
-        self._prev_mean: torch.Tensor | None = None
+        self._prev_mean = None
 
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -514,6 +517,7 @@ class TDMPCPolicy(PreTrainedPolicy):
         update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
 
 
+# TODO(Steven): forward implementation missing
 class TDMPCTOLD(nn.Module):
     """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
 
```
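The TDMPC and VQ-BeT policies now declare their instance attributes in `__init__` (keeping the annotated assignment there, with `reset()` only re-populating plain values), so the object's full state is visible in one place and attribute-defined-outside-init warnings go away. A schematic sketch with a made-up class:

```python
import torch


class ToyPolicy:
    def __init__(self) -> None:
        # Declare every attribute up front; reset() only refreshes their values.
        self._queues: dict[str, list] | None = None
        self._prev_mean: torch.Tensor | None = None
        self.reset()

    def reset(self) -> None:
        self._queues = {"action": []}
        self._prev_mean = None


policy = ToyPolicy()
print(policy._queues)
```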
```diff
@@ -144,7 +144,7 @@ class VQBeTConfig(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
 
-        """Input validation (not exhaustive)."""
+        # Input validation (not exhaustive).
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
@@ -70,6 +70,8 @@ class VQBeTPolicy(PreTrainedPolicy):
 
         self.vqbet = VQBeTModel(config)
 
+        self._queues = None
+
         self.reset()
 
     def get_optim_params(self) -> dict:
@@ -535,7 +537,7 @@ class VQBeTHead(nn.Module):
             cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
         )
         cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
-        NT, G, choices = cbet_probs.shape
+        NT, _G, choices = cbet_probs.shape
         sampled_centers = einops.rearrange(
             torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
             "(NT G) 1 -> NT G",
@@ -578,7 +580,7 @@ class VQBeTHead(nn.Module):
             "decoded_action": decoded_action,
         }
 
-    def loss_fn(self, pred, target, **kwargs):
+    def loss_fn(self, pred, target, **_kwargs):
         """
         for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
 
@@ -605,7 +607,7 @@ class VQBeTHead(nn.Module):
         # Figure out the loss for the actions.
         # First, we need to find the closest cluster center for each ground truth action.
         with torch.no_grad():
-            state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G
+            _state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G
 
         # Now we can compute the loss.
 
@@ -762,6 +764,7 @@ def _replace_submodules(
     return root_module
 
 
+# TODO(Steven): Missing implementation of forward, is it maybe vqvae_forward?
 class VqVae(nn.Module):
     def __init__(
         self,
@@ -876,13 +879,13 @@ class FocalLoss(nn.Module):
         self.gamma = gamma
         self.size_average = size_average
 
-    def forward(self, input, target):
-        if len(input.shape) == 3:
-            N, T, _ = input.shape
-            logpt = F.log_softmax(input, dim=-1)
+    def forward(self, forward_input, target):
+        if len(forward_input.shape) == 3:
+            N, T, _ = forward_input.shape
+            logpt = F.log_softmax(forward_input, dim=-1)
             logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
-        elif len(input.shape) == 2:
-            logpt = F.log_softmax(input, dim=-1)
+        elif len(forward_input.shape) == 2:
+            logpt = F.log_softmax(forward_input, dim=-1)
             logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
         pt = logpt.exp()
 
```
```diff
@@ -34,63 +34,58 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
 
 # ruff: noqa: N806
 
-"""
-This file is part of a VQ-BeT that utilizes code from the following repositories:
-
-    - Vector Quantize PyTorch code is licensed under the MIT License:
-        Original source: https://github.com/lucidrains/vector-quantize-pytorch
-
-    - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
-        Original source: https://github.com/karpathy/nanoGPT
-
-We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
-"""
-
-"""
-This is a part for nanoGPT that utilizes code from the following repository:
-
-    - Andrej Karpathy's nanoGPT implementation in PyTorch.
-        Original source: https://github.com/karpathy/nanoGPT
-
-    - The nanoGPT code is licensed under the MIT License:
-
-        MIT License
-
-        Copyright (c) 2022 Andrej Karpathy
-
-        Permission is hereby granted, free of charge, to any person obtaining a copy
-        of this software and associated documentation files (the "Software"), to deal
-        in the Software without restriction, including without limitation the rights
-        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-        copies of the Software, and to permit persons to whom the Software is
-        furnished to do so, subject to the following conditions:
-
-        The above copyright notice and this permission notice shall be included in all
-        copies or substantial portions of the Software.
-
-        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-        SOFTWARE.
-
-    - We've made some changes to the original code to adapt it to our needs.
-
-        Changed variable names:
-            - n_head -> gpt_n_head
-            - n_embd -> gpt_hidden_dim
-            - block_size -> gpt_block_size
-            - n_layer -> gpt_n_layer
-
-        class GPT(nn.Module):
-            - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
-            - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
-            - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
-"""
+# This file is part of a VQ-BeT that utilizes code from the following repositories:
+#
+# - Vector Quantize PyTorch code is licensed under the MIT License:
+#    Original source: https://github.com/lucidrains/vector-quantize-pytorch
+#
+# - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
+#    Original source: https://github.com/karpathy/nanoGPT
+#
+# We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
+
+# This is a part for nanoGPT that utilizes code from the following repository:
+#
+# - Andrej Karpathy's nanoGPT implementation in PyTorch.
+#    Original source: https://github.com/karpathy/nanoGPT
+#
+# - The nanoGPT code is licensed under the MIT License:
+#
+# MIT License
+#
+# Copyright (c) 2022 Andrej Karpathy
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# - We've made some changes to the original code to adapt it to our needs.
+#
+# Changed variable names:
+#  - n_head -> gpt_n_head
+#  - n_embd -> gpt_hidden_dim
+#  - block_size -> gpt_block_size
+#  - n_layer -> gpt_n_layer
+#
+#
+# class GPT(nn.Module):
+#  - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
+#  - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
+#  - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
 
 
 class CausalSelfAttention(nn.Module):
```
```diff
@@ -200,9 +195,9 @@ class GPT(nn.Module):
         n_params = sum(p.numel() for p in self.parameters())
         print("number of parameters: {:.2f}M".format(n_params / 1e6))
 
-    def forward(self, input, targets=None):
-        device = input.device
-        b, t, d = input.size()
+    def forward(self, forward_input):
+        device = forward_input.device
+        _, t, _ = forward_input.size()
         assert t <= self.config.gpt_block_size, (
             f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
         )
@@ -211,7 +206,7 @@ class GPT(nn.Module):
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
 
         # forward the GPT model itself
-        tok_emb = self.transformer.wte(input)  # token embeddings of shape (b, t, gpt_hidden_dim)
+        tok_emb = self.transformer.wte(forward_input)  # token embeddings of shape (b, t, gpt_hidden_dim)
         pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, gpt_hidden_dim)
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
```
```diff
@@ -285,51 +280,48 @@ class GPT(nn.Module):
         return decay, no_decay
 
 
-"""
-This file is a part for Residual Vector Quantization that utilizes code from the following repository:
-
-    - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
-        Original source: https://github.com/lucidrains/vector-quantize-pytorch
-
-    - The vector-quantize-pytorch code is licensed under the MIT License:
-
-        MIT License
-
-        Copyright (c) 2020 Phil Wang
-
-        Permission is hereby granted, free of charge, to any person obtaining a copy
-        of this software and associated documentation files (the "Software"), to deal
-        in the Software without restriction, including without limitation the rights
-        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-        copies of the Software, and to permit persons to whom the Software is
-        furnished to do so, subject to the following conditions:
-
-        The above copyright notice and this permission notice shall be included in all
-        copies or substantial portions of the Software.
-
-        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-        SOFTWARE.
-
-    - We've made some changes to the original code to adapt it to our needs.
-
-        class ResidualVQ(nn.Module):
-            - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
-                This enables the user to save an indicator whether the codebook is frozen or not.
-            - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
-                This is to make the function name more descriptive.
-
-        class VectorQuantize(nn.Module):
-            - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
-                These parameters are not used in the code.
-            - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
-                This is to make the function name more descriptive.
-
-"""
+# This file is a part for Residual Vector Quantization that utilizes code from the following repository:
+#
+# - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
+#    Original source: https://github.com/lucidrains/vector-quantize-pytorch
+#
+# - The vector-quantize-pytorch code is licensed under the MIT License:
+#
+# MIT License
+#
+# Copyright (c) 2020 Phil Wang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# - We've made some changes to the original code to adapt it to our needs.
+#
+# class ResidualVQ(nn.Module):
+#  - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
+#    This enables the user to save an indicator whether the codebook is frozen or not.
+#  - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
+#    This is to make the function name more descriptive.
+#
+# class VectorQuantize(nn.Module):
+#  - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
+#    These parameters are not used in the code.
+#  - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
+#    This is to make the function name more descriptive.
 
 
 class ResidualVQ(nn.Module):
@@ -479,6 +471,9 @@ class ResidualVQ(nn.Module):
 
         should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
 
+        null_indices = None
+        null_loss = None
+
         # sample a layer index at which to dropout further residual quantization
         # also prepare null indices and loss
 
@@ -933,7 +928,7 @@ class VectorQuantize(nn.Module):
         return quantize, embed_ind, loss
 
 
-def noop(*args, **kwargs):
+def noop(*_args, **_kwargs):
     pass
 
```
```diff
@@ -77,9 +77,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
         path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
         path.parent.mkdir(parents=True, exist_ok=True)
         img.save(str(path), quality=100)
-        logging.info(f"Saved image: {path}")
+        logging.info("Saved image: %s", path)
     except Exception as e:
-        logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
+        logging.error("Failed to save image for camera %s frame %s: %s", serial_number, frame_index, e)
 
 
 def save_images_from_cameras(
@@ -447,7 +447,7 @@ class IntelRealSenseCamera:
             num_tries += 1
             time.sleep(1 / self.fps)
             if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
-                raise Exception(
+                raise TimeoutError(
                     "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
                 )
 
```
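Replacing the bare `Exception` with `TimeoutError` gives callers a precise type to catch (and satisfies the broad-exception lint). A sketch of how a caller benefits; the function names here are hypothetical:

```python
def wait_for_capture_thread(started: bool) -> None:
    if not started:
        raise TimeoutError("The capture thread took too long to start.")


try:
    wait_for_capture_thread(started=False)
except TimeoutError as e:
    # Only timeouts are handled here; unrelated bugs still propagate.
    print(f"retrying after: {e}")
```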
```diff
@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import capture_timestamp_utc
 MAX_OPENCV_INDEX = 60
 
 
-def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
+def find_cameras(max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
     cameras = []
     if platform.system() == "Linux":
         print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
 
```
```diff
@@ -169,7 +169,8 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
     return steps
 
 
-def convert_to_bytes(value, bytes, mock=False):
+# TODO(Steven): Similar function in feetch.py, should be moved to a common place.
+def convert_to_bytes(value, byte, mock=False):
     if mock:
         return value
 
@@ -177,16 +178,16 @@ def convert_to_bytes(value, bytes, mock=False):
 
     # Note: No need to convert back into unsigned int, since this byte preprocessing
     # already handles it for us.
-    if bytes == 1:
+    if byte == 1:
         data = [
             dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
         ]
-    elif bytes == 2:
+    elif byte == 2:
         data = [
             dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
             dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
         ]
-    elif bytes == 4:
+    elif byte == 4:
         data = [
             dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
             dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
@@ -196,7 +197,7 @@ def convert_to_bytes(value, bytes, mock=False):
     else:
         raise NotImplementedError(
             f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
-            f"{bytes} is provided instead."
+            f"{byte} is provided instead."
         )
     return data
 
@@ -228,9 +229,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
     all_addr = []
     all_bytes = []
     for model in motor_models:
-        addr, bytes = model_ctrl_table[model][data_name]
+        addr, byte = model_ctrl_table[model][data_name]
         all_addr.append(addr)
-        all_bytes.append(bytes)
+        all_bytes.append(byte)
 
     if len(set(all_addr)) != 1:
         raise NotImplementedError(
```
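The `bytes` parameter becomes `byte` throughout the Dynamixel and Feetech buses because it shadowed the built-in `bytes` type inside these functions. A simplified, SDK-free sketch of the same helper, with plain bit-twiddling standing in for the `DXL_LOBYTE`/`DXL_HIBYTE` macros:

```python
def convert_to_bytes(value: int, byte: int) -> list[int]:
    # `byte` is the number of bytes to send; calling it `bytes` would shadow the builtin,
    # so e.g. bytes([1, 2]) inside this body would try to call the integer argument.
    if byte == 1:
        return [value & 0xFF]
    if byte == 2:
        return [value & 0xFF, (value >> 8) & 0xFF]
    if byte == 4:
        return [value & 0xFF, (value >> 8) & 0xFF, (value >> 16) & 0xFF, (value >> 24) & 0xFF]
    raise NotImplementedError(f"Expected 1, 2 or 4 bytes, got {byte}.")


print(convert_to_bytes(0x1234, 2))  # [52, 18] -> low byte first (little-endian)
```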
```diff
@@ -576,6 +577,8 @@ class DynamixelMotorsBus:
                 # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
                 low_factor = (start_pos - values[i]) / resolution
                 upp_factor = (end_pos - values[i]) / resolution
+            else:
+                raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
 
             if not in_range:
                 # Get first integer between the two bounds
@@ -596,10 +599,15 @@ class DynamixelMotorsBus:
                 elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                     out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                     in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
+                else:
+                    raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
 
                 logging.warning(
-                    f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
-                    f"from '{out_of_range_str}' to '{in_range_str}'."
+                    "Auto-correct calibration of motor '%s' by shifting value by {abs(factor)} full turns, "
+                    "from '%s' to '%s'.",
+                    name,
+                    out_of_range_str,
+                    in_range_str,
                 )
 
                 # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
```
```diff
@@ -656,8 +664,8 @@ class DynamixelMotorsBus:
             motor_ids = [motor_ids]
 
         assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
-        addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
-        group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
+        addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
+        group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
         for idx in motor_ids:
             group.addParam(idx)
 
@@ -674,7 +682,7 @@ class DynamixelMotorsBus:
 
         values = []
         for idx in motor_ids:
-            value = group.getData(idx, addr, bytes)
+            value = group.getData(idx, addr, byte)
             values.append(value)
 
         if return_list:
@@ -709,13 +717,13 @@ class DynamixelMotorsBus:
             models.append(model)
 
         assert_same_address(self.model_ctrl_table, models, data_name)
-        addr, bytes = self.model_ctrl_table[model][data_name]
+        addr, byte = self.model_ctrl_table[model][data_name]
         group_key = get_group_sync_key(data_name, motor_names)
 
         if data_name not in self.group_readers:
             # create new group reader
             self.group_readers[group_key] = dxl.GroupSyncRead(
-                self.port_handler, self.packet_handler, addr, bytes
+                self.port_handler, self.packet_handler, addr, byte
             )
             for idx in motor_ids:
                 self.group_readers[group_key].addParam(idx)
@@ -733,7 +741,7 @@ class DynamixelMotorsBus:
 
         values = []
         for idx in motor_ids:
-            value = self.group_readers[group_key].getData(idx, addr, bytes)
+            value = self.group_readers[group_key].getData(idx, addr, byte)
             values.append(value)
 
         values = np.array(values)
@@ -767,10 +775,10 @@ class DynamixelMotorsBus:
             values = [values]
 
         assert_same_address(self.model_ctrl_table, motor_models, data_name)
-        addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
-        group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
+        addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
+        group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
         for idx, value in zip(motor_ids, values, strict=True):
-            data = convert_to_bytes(value, bytes, self.mock)
+            data = convert_to_bytes(value, byte, self.mock)
             group.addParam(idx, data)
 
         for _ in range(num_retry):
@@ -821,17 +829,17 @@ class DynamixelMotorsBus:
             values = values.tolist()
 
         assert_same_address(self.model_ctrl_table, models, data_name)
-        addr, bytes = self.model_ctrl_table[model][data_name]
+        addr, byte = self.model_ctrl_table[model][data_name]
         group_key = get_group_sync_key(data_name, motor_names)
 
         init_group = data_name not in self.group_readers
         if init_group:
             self.group_writers[group_key] = dxl.GroupSyncWrite(
-                self.port_handler, self.packet_handler, addr, bytes
+                self.port_handler, self.packet_handler, addr, byte
             )
 
         for idx, value in zip(motor_ids, values, strict=True):
-            data = convert_to_bytes(value, bytes, self.mock)
+            data = convert_to_bytes(value, byte, self.mock)
             if init_group:
                 self.group_writers[group_key].addParam(idx, data)
             else:
 
```
@@ -148,7 +148,7 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
return steps


def convert_to_bytes(value, bytes, mock=False):
def convert_to_bytes(value, byte, mock=False):
if mock:
return value

@@ -156,16 +156,16 @@ def convert_to_bytes(value, bytes, mock=False):

# Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us.
if bytes == 1:
if byte == 1:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
]
elif bytes == 2:
elif byte == 2:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
]
elif bytes == 4:
elif byte == 4:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
@@ -175,7 +175,7 @@ def convert_to_bytes(value, bytes, mock=False):
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead."
f"{byte} is provided instead."
)
return data
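The rename from "bytes" to "byte" in this function avoids shadowing the built-in type inside the function body. A tiny illustration of the failure mode, with hypothetical names (not repository code):

def pack(value, bytes):
    # The parameter shadows the built-in, so calling bytes() here fails:
    # TypeError: 'int' object is not callable.
    return bytes(value)

def pack_fixed(value, n_bytes):
    # With a distinct parameter name the built-in stays usable.
    return value.to_bytes(n_bytes, "little")

assert pack_fixed(1, 2) == b"\x01\x00"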
|
||||
|
||||
@@ -207,9 +207,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
addr, bytes = model_ctrl_table[model][data_name]
|
||||
addr, byte = model_ctrl_table[model][data_name]
|
||||
all_addr.append(addr)
|
||||
all_bytes.append(bytes)
|
||||
all_bytes.append(byte)
|
||||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
@@ -557,6 +557,8 @@ class FeetechMotorsBus:
|
||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||
low_factor = (start_pos - values[i]) / resolution
|
||||
upp_factor = (end_pos - values[i]) / resolution
|
||||
else:
|
||||
raise ValueError(f"Unknown calibration mode {calib_mode}")
|
||||
|
||||
if not in_range:
|
||||
# Get first integer between the two bounds
|
||||
@@ -577,10 +579,16 @@ class FeetechMotorsBus:
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
else:
|
||||
raise ValueError(f"Unknown calibration mode {calib_mode}")
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
||||
"Auto-correct calibration of motor '%s' by shifting value by %s full turns, "
|
||||
"from '%s' to '%s'.",
|
||||
name,
|
||||
abs(factor),
|
||||
out_of_range_str,
|
||||
in_range_str,
|
||||
)
|
||||
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
@@ -674,8 +682,8 @@ class FeetechMotorsBus:
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
|
||||
for idx in motor_ids:
|
||||
group.addParam(idx)
|
||||
|
||||
@@ -692,7 +700,7 @@ class FeetechMotorsBus:
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = group.getData(idx, addr, bytes)
|
||||
value = group.getData(idx, addr, byte)
|
||||
values.append(value)
|
||||
|
||||
if return_list:
|
||||
@@ -727,7 +735,7 @@ class FeetechMotorsBus:
|
||||
models.append(model)
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
addr, byte = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
if data_name not in self.group_readers:
|
||||
@@ -737,7 +745,7 @@ class FeetechMotorsBus:
|
||||
|
||||
# create new group reader
|
||||
self.group_readers[group_key] = scs.GroupSyncRead(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
self.port_handler, self.packet_handler, addr, byte
|
||||
)
|
||||
for idx in motor_ids:
|
||||
self.group_readers[group_key].addParam(idx)
|
||||
@@ -755,7 +763,7 @@ class FeetechMotorsBus:
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
||||
value = self.group_readers[group_key].getData(idx, addr, byte)
|
||||
values.append(value)
|
||||
|
||||
values = np.array(values)
|
||||
@@ -792,10 +800,10 @@ class FeetechMotorsBus:
|
||||
values = [values]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
data = convert_to_bytes(value, byte, self.mock)
|
||||
group.addParam(idx, data)
|
||||
|
||||
for _ in range(num_retry):
|
||||
@@ -846,17 +854,17 @@ class FeetechMotorsBus:
|
||||
values = values.tolist()
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
addr, byte = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
init_group = data_name not in self.group_readers
|
||||
if init_group:
|
||||
self.group_writers[group_key] = scs.GroupSyncWrite(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
self.port_handler, self.packet_handler, addr, byte
|
||||
)
|
||||
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
data = convert_to_bytes(value, byte, self.mock)
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
|
||||
@@ -95,6 +95,8 @@ def move_to_calibrate(
|
||||
while_move_hook=None,
|
||||
):
|
||||
initial_pos = arm.read("Present_Position", motor_name)
|
||||
p_present_pos = None
|
||||
n_present_pos = None
|
||||
|
||||
if positive_first:
|
||||
p_present_pos = move_until_block(
|
||||
@@ -196,7 +198,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
|
||||
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
|
||||
|
||||
def in_between_move_hook():
|
||||
def in_between_move_hook_elbow():
|
||||
nonlocal arm, calib
|
||||
time.sleep(2)
|
||||
ef_pos = arm.read("Present_Position", "elbow_flex")
|
||||
@@ -207,14 +209,14 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
|
||||
print("Calibrate elbow_flex")
|
||||
calib["elbow_flex"] = move_to_calibrate(
|
||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
|
||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook_elbow
|
||||
)
|
||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
||||
time.sleep(1)
|
||||
|
||||
def in_between_move_hook():
|
||||
def in_between_move_hook_shoulder():
|
||||
nonlocal arm, calib
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
|
||||
|
||||
@@ -224,7 +226,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
"shoulder_lift",
|
||||
invert_drive_mode=True,
|
||||
positive_first=False,
|
||||
in_between_move_hook=in_between_move_hook,
|
||||
in_between_move_hook=in_between_move_hook_shoulder,
|
||||
)
|
||||
# add an 30 steps as offset to align with body
|
||||
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
|
||||
|
||||
@@ -67,14 +67,14 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||
return
|
||||
|
||||
if calib_file.exists():
|
||||
with open(calib_file) as f:
|
||||
with open(calib_file, encoding="utf-8") as f:
|
||||
calibration = json.load(f)
|
||||
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||
else:
|
||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||
with open(calib_file, "w") as f:
|
||||
with open(calib_file, "w", encoding="utf-8") as f:
|
||||
json.dump(calibration, f)
|
||||
try:
|
||||
motors_bus.set_calibration(calibration)
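The explicit encoding added to these open() calls pins file I/O to UTF-8 instead of the platform's locale default (still cp1252 on many Windows setups), and it clears the "unspecified encoding" lint warning. A short sketch of the round-trip under that convention (hypothetical file name, not repository code):

import json

with open("follower_arm.json", "w", encoding="utf-8") as f:
    json.dump({"wrist_flex": {"homing_offset": 0}}, f)

with open("follower_arm.json", encoding="utf-8") as f:
    calibration = json.load(f)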
|
||||
|
||||
@@ -47,8 +47,10 @@ def ensure_safe_goal_position(
|
||||
if not torch.allclose(goal_pos, safe_goal_pos):
|
||||
logging.warning(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f" requested relative goal position target: {diff}\n"
|
||||
f" clamped relative goal position target: {safe_diff}"
|
||||
" requested relative goal position target: %s\n"
|
||||
" clamped relative goal position target: %s",
|
||||
diff,
|
||||
safe_diff,
|
||||
)
|
||||
|
||||
return safe_goal_pos
|
||||
@@ -245,6 +247,8 @@ class ManipulatorRobot:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
else:
|
||||
raise NotImplementedError(f"Robot type {self.robot_type} is not supported")
|
||||
|
||||
# We assume that at connection time, arms are in a rest position, and torque can
|
||||
# be safely disabled to run calibration and/or set robot preset configurations.
|
||||
@@ -302,7 +306,7 @@ class ManipulatorRobot:
|
||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||
|
||||
if arm_calib_path.exists():
|
||||
with open(arm_calib_path) as f:
|
||||
with open(arm_calib_path, encoding="utf-8") as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
||||
@@ -322,7 +326,7 @@ class ManipulatorRobot:
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
with open(arm_calib_path, "w", encoding="utf-8") as f:
|
||||
json.dump(calibration, f)
|
||||
|
||||
return calibration
|
||||
|
||||
@@ -262,14 +262,14 @@ class MobileManipulator:
|
||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||
|
||||
if arm_calib_path.exists():
|
||||
with open(arm_calib_path) as f:
|
||||
with open(arm_calib_path, encoding="utf-8") as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
with open(arm_calib_path, "w", encoding="utf-8") as f:
|
||||
json.dump(calibration, f)
|
||||
|
||||
return calibration
|
||||
@@ -372,6 +372,7 @@ class MobileManipulator:
|
||||
|
||||
present_speed = self.last_present_speed
|
||||
|
||||
# TODO(Steven): [WARN] Plenty of general exceptions
|
||||
except Exception as e:
|
||||
print(f"[DEBUG] Error decoding video message: {e}")
|
||||
# If decode fails, fall back to old data
|
||||
|
||||
@@ -68,9 +68,9 @@ class TimeBenchmark(ContextDecorator):
|
||||
Block took approximately 10.00 milliseconds
|
||||
"""
|
||||
|
||||
def __init__(self, print=False):
|
||||
def __init__(self, print_time=False):
|
||||
self.local = threading.local()
|
||||
self.print_time = print
|
||||
self.print_time = print_time
|
||||
|
||||
def __enter__(self):
|
||||
self.local.start_time = time.perf_counter()
|
||||
|
||||
@@ -46,7 +46,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
else:
|
||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||
package_exists = False
|
||||
logging.debug(f"Detected {pkg_name} version: {package_version}")
|
||||
logging.debug("Detected %s version: %s", {pkg_name}, package_version)
|
||||
if return_version:
|
||||
return package_exists, package_version
|
||||
else:
|
||||
|
||||
@@ -27,6 +27,8 @@ class AverageMeter:
|
||||
def __init__(self, name: str, fmt: str = ":f"):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -69,7 +69,7 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
case _:
|
||||
device = torch.device(try_device)
|
||||
if log:
|
||||
logging.warning(f"Using custom {try_device} device.")
|
||||
logging.warning("Using custom %s device.", try_device)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class WandBLogger:
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
logging.info("Track this run --> %s", colored(wandb.run.get_url(), "yellow", attrs=["bold"]))
|
||||
self._wandb = wandb
|
||||
|
||||
def log_policy(self, checkpoint_dir: Path):
|
||||
@@ -108,7 +108,7 @@ class WandBLogger:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
'WandB logging of key "%s" was ignored as its type is not handled by this wrapper.', k
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
@@ -64,13 +64,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logging.warning("Device '%s' is not available. Switching to '%s'.", self.device, auto_device)
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
"Automatic Mixed Precision (amp) is not available on device '%s'. Deactivating AMP.",
|
||||
self.device,
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
@@ -78,15 +79,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
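Replacing @abc.abstractproperty with the @property / @abc.abstractmethod stack follows the deprecation of abstractproperty (deprecated since Python 3.3); the ordering matters, with @property outermost. A minimal sketch with a hypothetical subclass (not repository code):

import abc

class Config(abc.ABC):
    @property
    @abc.abstractmethod
    def action_delta_indices(self) -> list | None:
        raise NotImplementedError

class DummyConfig(Config):
    @property
    def action_delta_indices(self) -> list | None:
        return [0, 1, 2]

# Instantiation fails with TypeError if the property override is missing.
assert DummyConfig().action_delta_indices == [0, 1, 2]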
|
||||
|
||||
@@ -128,7 +132,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
with open(save_directory / CONFIG_NAME, "w", encoding="utf-8") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -123,7 +123,10 @@ class TrainPipelineConfig(HubMixin):
|
||||
return draccus.encode(self)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
with (
|
||||
open(save_directory / TRAIN_CONFIG_NAME, "w", encoding="utf-8") as f,
|
||||
draccus.config_type("json"),
|
||||
):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -90,6 +90,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
print("Scanning all baudrates and motor indices")
|
||||
all_baudrates = set(series_baudrate_table.values())
|
||||
motor_index = -1 # Set the motor index to an out-of-range value.
|
||||
baudrate = None
|
||||
|
||||
for baudrate in all_baudrates:
|
||||
motor_bus.set_bus_baudrate(baudrate)
|
||||
|
||||
@@ -81,6 +81,7 @@ This might require a sudo permission to allow your terminal to monitor keyboard
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||
"""
|
||||
|
||||
# TODO(Steven): This script should be updated to use the new robot API and the new dataset API.
|
||||
import argparse
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
@@ -59,7 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"
|
||||
|
||||
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
||||
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
||||
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||
cuda_version = torch.version.cuda if HAS_TORCH and torch.version.cuda is not None else "N/A"
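Switching to torch.version.cuda avoids reaching into the private torch._C API; it reports the CUDA toolkit version the installed wheel was built against and is None on CPU-only builds, which is why the guard is kept. For example (output depends on the installed wheel):

import torch

print(torch.version.cuda or "N/A")  # e.g. "12.1" on a CUDA build, "N/A" on CPU-only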
|
||||
|
||||
|
||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||
|
||||
@@ -259,6 +259,10 @@ def eval_policy(
|
||||
threads = [] # for video saving threads
|
||||
n_episodes_rendered = 0 # for saving the correct number of videos
|
||||
|
||||
video_paths: list[str] = [] # max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = [] # max_episodes_rendered > 0
|
||||
episode_data: dict | None = None # return_episode_data == True
|
||||
|
||||
# Callback for visualization.
|
||||
def render_frame(env: gym.vector.VectorEnv):
|
||||
# noqa: B023
|
||||
@@ -271,19 +275,11 @@ def eval_policy(
|
||||
# Here we must render all frames and discard any we don't need.
|
||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
# we dont want progress bar when we use slurm, since it clutters the logs
|
||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
||||
for batch_ix in progbar:
|
||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||
# step.
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
@@ -320,13 +316,19 @@ def eval_policy(
|
||||
else:
|
||||
all_seeds.append(None)
|
||||
|
||||
# FIXME: episode_data is either None or it doesn't exist
|
||||
if return_episode_data:
|
||||
if episode_data is None:
|
||||
start_data_index = 0
|
||||
elif isinstance(episode_data, dict):
|
||||
start_data_index = episode_data["index"][-1].item() + 1
|
||||
else:
|
||||
start_data_index = 0
|
||||
|
||||
this_episode_data = _compile_episode_data(
|
||||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
||||
start_data_index=start_data_index,
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
@@ -453,6 +455,7 @@ def _compile_episode_data(
|
||||
return data_dict
|
||||
|
||||
|
||||
# TODO(Steven): [WARN] Redefining built-in 'eval'
|
||||
@parser.wrap()
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
@@ -489,7 +492,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
print(info["aggregated"])
|
||||
|
||||
# Save info
|
||||
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
||||
with open(Path(cfg.output_dir) / "eval_info.json", "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
env.close()
|
||||
|
||||
@@ -53,6 +53,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file
|
||||
|
||||
# TODO(Steven): #711 Broke this
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
@@ -89,7 +90,7 @@ def save_meta_data(
|
||||
|
||||
# save info
|
||||
info_path = meta_data_dir / "info.json"
|
||||
with open(str(info_path), "w") as f:
|
||||
with open(str(info_path), "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=4)
|
||||
|
||||
# save stats
|
||||
@@ -120,11 +121,11 @@ def push_dataset_card_to_hub(
|
||||
repo_id: str,
|
||||
revision: str | None,
|
||||
tags: list | None = None,
|
||||
license: str = "apache-2.0",
|
||||
dataset_license: str = "apache-2.0",
|
||||
**card_kwargs,
|
||||
):
|
||||
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
||||
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
|
||||
card = create_lerobot_dataset_card(tags=tags, license=dataset_license, **card_kwargs)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
||||
|
||||
|
||||
@@ -213,6 +214,7 @@ def push_dataset_to_hub(
|
||||
encoding,
|
||||
)
|
||||
|
||||
# TODO(Steven): This doesn't seem to exist, maybe it was removed/changed recently?
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
hf_dataset=hf_dataset,
|
||||
|
||||
@@ -155,12 +155,14 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
logging.info("cfg.env.task=%s", cfg.env.task)
|
||||
logging.info("cfg.steps=%s (%s)", cfg.steps, format_big_number(cfg.steps))
|
||||
logging.info("dataset.num_frames=%s (%s)", dataset.num_frames, format_big_number(dataset.num_frames))
|
||||
logging.info("dataset.num_episodes=%s", dataset.num_episodes)
|
||||
logging.info(
|
||||
"num_learnable_params=%s (%s)", num_learnable_params, format_big_number(num_learnable_params)
|
||||
)
|
||||
logging.info("num_total_params=%s (%s)", num_total_params, format_big_number(num_total_params))
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
@@ -238,7 +240,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
logging.info("Checkpoint policy after step %s", step)
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
@@ -247,7 +249,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
logging.info("Eval policy at step %s", step)
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
|
||||
@@ -150,7 +150,7 @@ def run_server(
|
||||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
str(dataset.meta.version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
@@ -358,7 +358,7 @@ def visualize_dataset_html(
|
||||
if force_override:
|
||||
shutil.rmtree(output_dir)
|
||||
else:
|
||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
||||
logging.info("Output directory already exists. Loading from it: '%s'", {output_dir})
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
tests/fixtures/dataset_factories.py
@@ -52,16 +52,16 @@ def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
return task_to_task_index[task]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
@pytest.fixture(name="img_tensor_factory", scope="session")
|
||||
def fixture_img_tensor_factory():
|
||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||
return torch.rand((channels, height, width), dtype=dtype)
|
||||
|
||||
return _create_img_tensor
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
@pytest.fixture(name="img_array_factory", scope="session")
|
||||
def fixture_img_array_factory():
|
||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
@@ -76,8 +76,8 @@ def img_array_factory():
|
||||
return _create_img_array
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_factory(img_array_factory):
|
||||
@pytest.fixture(name="img_factory", scope="session")
|
||||
def fixture_img_factory(img_array_factory):
|
||||
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
||||
img_array = img_array_factory(height=height, width=width)
|
||||
return PIL.Image.fromarray(img_array)
|
||||
@@ -85,13 +85,17 @@ def img_factory(img_array_factory):
|
||||
return _create_img
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def features_factory():
|
||||
@pytest.fixture(name="features_factory", scope="session")
|
||||
def fixture_features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
motor_features: dict | None = None,
|
||||
camera_features: dict | None = None,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if motor_features is None:
|
||||
motor_features = DUMMY_MOTOR_FEATURES
|
||||
if camera_features is None:
|
||||
camera_features = DUMMY_CAMERA_FEATURES
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||
@@ -107,8 +111,8 @@ def features_factory():
|
||||
return _create_features
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory(features_factory):
|
||||
@pytest.fixture(name="info_factory", scope="session")
|
||||
def fixture_info_factory(features_factory):
|
||||
def _create_info(
|
||||
codebase_version: str = CODEBASE_VERSION,
|
||||
fps: int = DEFAULT_FPS,
|
||||
@@ -121,10 +125,14 @@ def info_factory(features_factory):
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
motor_features: dict | None = None,
|
||||
camera_features: dict | None = None,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if motor_features is None:
|
||||
motor_features = DUMMY_MOTOR_FEATURES
|
||||
if camera_features is None:
|
||||
camera_features = DUMMY_CAMERA_FEATURES
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
@@ -145,8 +153,8 @@ def info_factory(features_factory):
|
||||
return _create_info
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_factory():
|
||||
@pytest.fixture(name="stats_factory", scope="session")
|
||||
def fixture_stats_factory():
|
||||
def _create_stats(
|
||||
features: dict[str] | None = None,
|
||||
) -> dict:
|
||||
@@ -175,8 +183,8 @@ def stats_factory():
|
||||
return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_stats_factory(stats_factory):
|
||||
@pytest.fixture(name="episodes_stats_factory", scope="session")
|
||||
def fixture_episodes_stats_factory(stats_factory):
|
||||
def _create_episodes_stats(
|
||||
features: dict[str],
|
||||
total_episodes: int = 3,
|
||||
@@ -192,8 +200,8 @@ def episodes_stats_factory(stats_factory):
|
||||
return _create_episodes_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_factory():
|
||||
@pytest.fixture(name="tasks_factory", scope="session")
|
||||
def fixture_tasks_factory():
|
||||
def _create_tasks(total_tasks: int = 3) -> int:
|
||||
tasks = {}
|
||||
for task_index in range(total_tasks):
|
||||
@@ -204,8 +212,8 @@ def tasks_factory():
|
||||
return _create_tasks
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_factory(tasks_factory):
|
||||
@pytest.fixture(name="episodes_factory", scope="session")
|
||||
def fixture_episodes_factory(tasks_factory):
|
||||
def _create_episodes(
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
@@ -252,8 +260,8 @@ def episodes_factory(tasks_factory):
|
||||
return _create_episodes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
@pytest.fixture(name="hf_dataset_factory", scope="session")
|
||||
def fixture_hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def _create_hf_dataset(
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
@@ -310,8 +318,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
return _create_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
@pytest.fixture(name="lerobot_dataset_metadata_factory", scope="session")
|
||||
def fixture_lerobot_dataset_metadata_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
episodes_stats_factory,
|
||||
@@ -364,8 +372,8 @@ def lerobot_dataset_metadata_factory(
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
@pytest.fixture(name="lerobot_dataset_factory", scope="session")
|
||||
def fixture_lerobot_dataset_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
episodes_stats_factory,
|
||||
@@ -443,6 +451,6 @@ def lerobot_dataset_factory(
|
||||
return _create_lerobot_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||
@pytest.fixture(name="empty_lerobot_dataset_factory", scope="session")
|
||||
def fixture_empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
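A recurring change in the factories above: defaults such as DUMMY_MOTOR_FEATURES are no longer bound directly as argument defaults but resolved inside the function behind a None sentinel, so the module-level dict is never silently shared as a default object across calls. A stripped-down sketch of the pattern (hypothetical constant, not repository code):

DUMMY_MOTOR_FEATURES = {"state": {"dtype": "float32", "shape": (6,)}}

def create_features(motor_features: dict | None = None) -> dict:
    # Resolve the default at call time instead of at function definition time.
    if motor_features is None:
        motor_features = DUMMY_MOTOR_FEATURES
    return {key: dict(spec) for key, spec in motor_features.items()}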
|
||||
|
||||
tests/fixtures/files.py
@@ -31,12 +31,12 @@ from lerobot.common.datasets.utils import (
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_path(info_factory):
|
||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
||||
def _create_info_json_file(input_dir: Path, info: dict | None = None) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
fpath = dir / INFO_PATH
|
||||
fpath = input_dir / INFO_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
with open(fpath, "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
|
||||
@@ -45,12 +45,12 @@ def info_path(info_factory):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_path(stats_factory):
|
||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
||||
def _create_stats_json_file(input_dir: Path, stats: dict | None = None) -> Path:
|
||||
if not stats:
|
||||
stats = stats_factory()
|
||||
fpath = dir / STATS_PATH
|
||||
fpath = input_dir / STATS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
with open(fpath, "w", encoding="utf-8") as f:
|
||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
|
||||
@@ -59,10 +59,10 @@ def stats_path(stats_factory):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_stats_path(episodes_stats_factory):
|
||||
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
||||
def _create_episodes_stats_jsonl_file(input_dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory()
|
||||
fpath = dir / EPISODES_STATS_PATH
|
||||
fpath = input_dir / EPISODES_STATS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(episodes_stats.values())
|
||||
@@ -73,10 +73,10 @@ def episodes_stats_path(episodes_stats_factory):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_path(tasks_factory):
|
||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
||||
def _create_tasks_jsonl_file(input_dir: Path, tasks: list | None = None) -> Path:
|
||||
if not tasks:
|
||||
tasks = tasks_factory()
|
||||
fpath = dir / TASKS_PATH
|
||||
fpath = input_dir / TASKS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(tasks.values())
|
||||
@@ -87,10 +87,10 @@ def tasks_path(tasks_factory):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episode_path(episodes_factory):
|
||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
||||
def _create_episodes_jsonl_file(input_dir: Path, episodes: list | None = None) -> Path:
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
fpath = dir / EPISODES_PATH
|
||||
fpath = input_dir / EPISODES_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(episodes.values())
|
||||
@@ -102,7 +102,7 @@ def episode_path(episodes_factory):
|
||||
@pytest.fixture(scope="session")
|
||||
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_single_episode_parquet(
|
||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
input_dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
@@ -112,7 +112,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
data_path = info["data_path"]
|
||||
chunks_size = info["chunks_size"]
|
||||
ep_chunk = ep_idx // chunks_size
|
||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
table = hf_dataset.data.table
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
@@ -125,7 +125,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
@pytest.fixture(scope="session")
|
||||
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_multi_episode_parquet(
|
||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
input_dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
@@ -137,11 +137,11 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
total_episodes = info["total_episodes"]
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_chunk = ep_idx // chunks_size
|
||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
table = hf_dataset.data.table
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
pq.write_table(ep_table, fpath)
|
||||
return dir / "data"
|
||||
return input_dir / "data"
|
||||
|
||||
return _create_multi_episode_parquet
|
||||
|
||||
tests/fixtures/hub.py
@@ -81,12 +81,12 @@ def mock_snapshot_download_factory(
|
||||
return None
|
||||
|
||||
def _mock_snapshot_download(
|
||||
repo_id: str,
|
||||
_repo_id: str,
|
||||
*_args,
|
||||
local_dir: str | Path | None = None,
|
||||
allow_patterns: str | list[str] | None = None,
|
||||
ignore_patterns: str | list[str] | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
**_kwargs,
|
||||
) -> str:
|
||||
if not local_dir:
|
||||
local_dir = LEROBOT_TEST_DIR
|
||||
|
||||
tests/fixtures/optimizers.py
@@ -18,13 +18,13 @@ from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_params():
|
||||
@pytest.fixture(name="model_params")
|
||||
def fixture_model_params():
|
||||
return [torch.nn.Parameter(torch.randn(10, 10))]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def optimizer(model_params):
|
||||
@pytest.fixture(name="optimizer")
|
||||
def fixture_optimizer(model_params):
|
||||
optimizer = AdamConfig().build(model_params)
|
||||
# Dummy step to populate state
|
||||
loss = sum(param.sum() for param in model_params)
|
||||
@@ -33,7 +33,7 @@ def optimizer(model_params):
|
||||
return optimizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scheduler(optimizer):
|
||||
@pytest.fixture(name="scheduler")
|
||||
def fixture_scheduler(optimizer):
|
||||
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
|
||||
return config.build(optimizer, num_training_steps=100)
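The fixture changes in these test files follow one pattern: the fixture keeps its public name via name=..., while the function gets a fixture_ prefix, so the function name no longer collides with the argument name that tests request (the redefined-outer-name warning that the pylint disables further down also refer to). A minimal sketch (hypothetical fixture, not repository code):

import pytest

@pytest.fixture(name="optimizer_cfg")
def fixture_optimizer_cfg():
    return {"lr": 1e-3}

def test_learning_rate(optimizer_cfg):
    # Requested by its public name; the module-level function name stays distinct.
    assert optimizer_cfg["lr"] == 1e-3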
|
||||
|
||||
@@ -22,6 +22,8 @@ from lerobot.common.datasets.utils import create_lerobot_dataset_card
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
# TODO(Steven): Base class CardData should have 'tags' as a member if we want RepoCard to hold a reference to this abstraction
# card.data gives a CardData type; implementations of this class do have 'tags' but the base class doesn't
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
|
||||
@@ -57,7 +57,7 @@ def rotate(color_image, rotation):
|
||||
|
||||
|
||||
class VideoCapture:
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
self._mock_dict = {
|
||||
CAP_PROP_FPS: 30,
|
||||
CAP_PROP_FRAME_WIDTH: 640,
|
||||
|
||||
@@ -24,10 +24,9 @@ DEFAULT_BAUDRATE = 9_600
|
||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
def convert_to_bytes(value, _byte):
|
||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||
# `convert_bytes_to_value`
|
||||
del bytes # unused
|
||||
return value
|
||||
|
||||
|
||||
@@ -74,7 +73,7 @@ class PacketHandler:
|
||||
|
||||
|
||||
class GroupSyncRead:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
def __init__(self, _port_handler, packet_handler, _address, _byte):
|
||||
self.packet_handler = packet_handler
|
||||
|
||||
def addParam(self, motor_index): # noqa: N802
|
||||
@@ -85,12 +84,12 @@ class GroupSyncRead:
|
||||
def txRxPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def getData(self, index, address, bytes): # noqa: N802
|
||||
def getData(self, index, address, _byte): # noqa: N802
|
||||
return self.packet_handler.data[index][address]
|
||||
|
||||
|
||||
class GroupSyncWrite:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
def __init__(self, _port_handler, packet_handler, address, _byte):
|
||||
self.packet_handler = packet_handler
|
||||
self.address = address
|
||||
|
||||
|
||||
@@ -27,6 +27,13 @@ class format(enum.Enum): # noqa: N801
|
||||
|
||||
|
||||
class config: # noqa: N801
|
||||
device_enabled = None
|
||||
stream_type = None
|
||||
width = None
|
||||
height = None
|
||||
color_format = None
|
||||
fps = None
|
||||
|
||||
def enable_device(self, device_id: str):
|
||||
self.device_enabled = device_id
|
||||
|
||||
@@ -125,8 +132,7 @@ class RSDevice:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_info(self, camera_info) -> str:
|
||||
del camera_info # unused
|
||||
def get_info(self, _camera_info) -> str:
|
||||
# return fake serial number
|
||||
return "123456789"
|
||||
|
||||
@@ -145,4 +151,3 @@ class camera_info: # noqa: N801
|
||||
|
||||
def __init__(self, serial_number):
|
||||
del serial_number
|
||||
pass
|
||||
|
||||
@@ -24,10 +24,10 @@ DEFAULT_BAUDRATE = 1_000_000
|
||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
def convert_to_bytes(value, byte):
|
||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||
# `convert_bytes_to_value`
|
||||
del bytes # unused
|
||||
del byte # unused
|
||||
return value
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class PacketHandler:
|
||||
|
||||
|
||||
class GroupSyncRead:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
def __init__(self, _port_handler, packet_handler, _address, _byte):
|
||||
self.packet_handler = packet_handler
|
||||
|
||||
def addParam(self, motor_index): # noqa: N802
|
||||
@@ -96,12 +96,12 @@ class GroupSyncRead:
|
||||
def txRxPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def getData(self, index, address, bytes): # noqa: N802
|
||||
def getData(self, index, address, _byte): # noqa: N802
|
||||
return self.packet_handler.data[index][address]
|
||||
|
||||
|
||||
class GroupSyncWrite:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
def __init__(self, _port_handler, packet_handler, address, _byte):
|
||||
self.packet_handler = packet_handler
|
||||
self.address = address
|
||||
|
||||
|
||||
@@ -81,11 +81,11 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for dataset in [
|
||||
for available_dataset in [
|
||||
"lerobot/pusht",
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/xarm_lift_medium",
|
||||
"lerobot/nyu_franka_play_dataset",
|
||||
"lerobot/cmu_stretch",
|
||||
]:
|
||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
|
||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=available_dataset)
|
||||
|
||||
@@ -51,7 +51,7 @@ def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
|
||||
}
|
||||
|
||||
frames = {"original_frame": original_frame}
|
||||
for tf_type, tf_name, min_max_values in transforms.items():
|
||||
for tf_type, tf_name, min_max_values in transforms:
|
||||
for min_max in min_max_values:
|
||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
|
||||
@@ -150,6 +150,7 @@ def test_camera(request, camera_type, mock):
|
||||
else:
|
||||
import cv2
|
||||
|
||||
manual_rot_img: np.ndarray = None
|
||||
if rotation is None:
|
||||
manual_rot_img = ori_color_image
|
||||
assert camera.rotation is None
|
||||
@@ -197,10 +198,14 @@ def test_camera(request, camera_type, mock):
|
||||
@require_camera
|
||||
def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
|
||||
# TODO(rcadene): refactor
|
||||
save_images_from_cameras = None
|
||||
|
||||
if camera_type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
|
||||
else:
|
||||
raise ValueError(f"Unsupported camera type: {camera_type}")
|
||||
|
||||
# Small `record_time_s` to speedup unit tests
|
||||
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
|
||||
|
||||
@@ -30,12 +30,12 @@ from lerobot.common.datasets.compute_stats import (
|
||||
)
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
def mock_load_image_as_numpy(_path, dtype, channel_first):
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_array():
|
||||
@pytest.fixture(name="sample_array")
|
||||
def fixture_sample_array():
|
||||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ def test_sample_indices():
|
||||
|
||||
|
||||
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
||||
def test_sample_images(mock_load):
|
||||
def test_sample_images(_mock_load):
|
||||
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
||||
images = sample_images(image_paths)
|
||||
assert isinstance(images, np.ndarray)
|
||||
|
||||
@@ -48,8 +48,8 @@ from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
@pytest.fixture(name="image_dataset")
|
||||
def fixture_image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {
|
||||
"image": {
|
||||
"dtype": "image",
|
||||
@@ -374,7 +374,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
if required:
|
||||
assert key in item, f"{key}"
|
||||
else:
|
||||
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
|
||||
logging.warning('Missing key in dataset: "%s" not in %s.', key, dataset)
|
||||
continue
|
||||
|
||||
if delta_timestamps is not None and key in delta_timestamps:
|
||||
|
||||
@@ -42,7 +42,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
ep_table = table.filter(
|
||||
pc.equal(table["episode_index"], ep_idx)
|
||||
) # TODO(Steven): What is this check supposed to do?
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lengths = list(accumulate(episode_lengths))
|
||||
@@ -52,8 +54,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def synced_timestamps_factory(hf_dataset_factory):
|
||||
@pytest.fixture(name="synced_timestamps_factory", scope="module")
|
||||
def fixture_synced_timestamps_factory(hf_dataset_factory):
|
||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
hf_dataset = hf_dataset_factory(fps=fps)
|
||||
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
||||
@@ -64,8 +66,8 @@ def synced_timestamps_factory(hf_dataset_factory):
|
||||
return _create_synced_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
@pytest.fixture(name="unsynced_timestamps_factory", scope="module")
|
||||
def fixture_unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_unsynced_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
@@ -76,8 +78,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
return _create_unsynced_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
@pytest.fixture(name="slightly_off_timestamps_factory", scope="module")
|
||||
def fixture_slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_slightly_off_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
@@ -88,22 +90,26 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
return _create_slightly_off_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
@pytest.fixture(name="valid_delta_timestamps_factory", scope="module")
|
||||
def fixture_valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(
|
||||
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
||||
fps: int = 30, keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)
|
||||
) -> dict:
|
||||
if keys is None:
|
||||
keys = DUMMY_MOTOR_FEATURES
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
return _create_valid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
@pytest.fixture(name="invalid_delta_timestamps_factory", scope="module")
|
||||
def fixture_invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_invalid_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||
) -> dict:
|
||||
if keys is None:
|
||||
keys = DUMMY_MOTOR_FEATURES
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just outside tolerance
|
||||
for key in keys:
|
||||
@@ -113,11 +119,13 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
return _create_invalid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
@pytest.fixture(name="slightly_off_delta_timestamps_factory", scope="module")
|
||||
def fixture_slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_slightly_off_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||
) -> dict:
|
||||
if keys is None:
|
||||
keys = DUMMY_MOTOR_FEATURES
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just inside tolerance
|
||||
for key in delta_timestamps:
|
||||
@@ -128,9 +136,11 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
return _create_slightly_off_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices_factory():
|
||||
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||
@pytest.fixture(name="delta_indices_factory", scope="module")
|
||||
def fixture_delta_indices_factory():
|
||||
def _delta_indices(keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||
if keys is None:
|
||||
keys = DUMMY_MOTOR_FEATURES
|
||||
return {key: list(range(*min_max_range)) for key in keys}
|
||||
|
||||
return _delta_indices
|
||||
|
||||
@@ -38,7 +38,7 @@ def _run_script(path):
|
||||
|
||||
|
||||
def _read_file(path):
|
||||
with open(path) as file:
|
||||
with open(path, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
|
||||
@@ -37,8 +37,8 @@ from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def color_jitters():
|
||||
@pytest.fixture(name="color_jitters")
|
||||
def fixture_color_jitters():
|
||||
return [
|
||||
v2.ColorJitter(brightness=0.5),
|
||||
v2.ColorJitter(contrast=0.5),
|
||||
@@ -46,18 +46,18 @@ def color_jitters():
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def single_transforms():
|
||||
@pytest.fixture(name="single_transforms")
|
||||
def fixture_single_transforms():
|
||||
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img_tensor(single_transforms):
|
||||
@pytest.fixture(name="img_tensor")
|
||||
def fixture_img_tensor(single_transforms):
|
||||
return single_transforms["original_frame"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_transforms():
|
||||
@pytest.fixture(name="default_transforms")
|
||||
def fixture_default_transforms():
|
||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ import pytest
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_json_file(tmp_path: Path):
|
||||
@pytest.fixture(name="tmp_json_file")
|
||||
def fixture_tmp_json_file(tmp_path: Path):
|
||||
"""Writes `data` to a temporary JSON file and returns the file's path."""
|
||||
|
||||
def _write(data: Any) -> Path:
|
||||
|
||||
@@ -16,8 +16,8 @@ import pytest
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker


@pytest.fixture
def mock_metrics():
@pytest.fixture(name="mock_metrics")
def fixture_mock_metrics():
    return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}

@@ -87,6 +87,7 @@ def test_metrics_tracker_getattr(mock_metrics):
        _ = tracker.non_existent_metric


# TODO(Steven): I don't understand what's supposed to happen here
def test_metrics_tracker_setattr(mock_metrics):
    tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
    tracker.loss = 2.0

@@ -74,7 +74,7 @@ def test_non_mutate():
def test_index_error_no_data():
    buffer, _ = make_new_buffer()
    with pytest.raises(IndexError):
        buffer[0]
        _ = buffer[0]


def test_index_error_with_data():

@@ -83,9 +83,9 @@ def test_index_error_with_data():
    new_data = make_spoof_data_frames(1, n_frames)
    buffer.add_data(new_data)
    with pytest.raises(IndexError):
        buffer[n_frames]
        _ = buffer[n_frames]
    with pytest.raises(IndexError):
        buffer[-n_frames - 1]
        _ = buffer[-n_frames - 1]


@pytest.mark.parametrize("do_reload", [False, True])

@@ -185,7 +185,7 @@ def test_delta_timestamps_outside_tolerance_inside_episode_range():
    buffer.add_data(new_data)
    buffer.tolerance_s = 0.04
    with pytest.raises(AssertionError):
        buffer[2]
        _ = buffer[2]


def test_delta_timestamps_outside_tolerance_outside_episode_range():

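The `_ = buffer[...]` rewrites keep the indexing that is expected to raise while avoiding the pointless-statement lint on a bare expression inside `pytest.raises`. Roughly, with a plain list standing in for the buffer:

```python
import pytest


def test_index_error_on_empty_sequence():
    items: list[int] = []  # stand-in for the buffer under test
    with pytest.raises(IndexError):
        # A bare `items[0]` would be flagged as a statement with no effect;
        # binding the result to `_` keeps the access that raises IndexError.
        _ = items[0]
```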
@@ -229,6 +229,7 @@ def test_compute_sampler_weights_trivial(
    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
    )
    expected_weights: torch.Tensor = None
    if offline_dataset_size == 0 or online_dataset_size == 0:
        expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
    elif online_sampling_ratio == 0:

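Annotating and initialising `expected_weights` before the `if`/`elif` chain ensures the name is bound on every path, which is what the added line above is doing. A small illustration of the same idea, with made-up sizes and an explicit `| None` annotation:

```python
import torch


def expected_weights_for(offline_size: int, online_size: int) -> torch.Tensor:
    # Bind the name once up front; the branches below only reassign it,
    # so no code path can read an unassigned variable.
    weights: torch.Tensor | None = None
    if offline_size == 0 or online_size == 0:
        weights = torch.ones(offline_size + online_size)
    else:
        weights = torch.full((offline_size + online_size,), 0.5)
    return weights


print(expected_weights_for(0, 4))  # tensor of four ones
```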
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=redefined-outer-name, unused-argument
import inspect
from copy import deepcopy
from pathlib import Path

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=redefined-outer-name, unused-argument
import random

import numpy as np

@@ -32,16 +32,16 @@ from lerobot.common.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"

TEST_ROBOT_TYPES = []
for robot_type in available_robots:
    TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
for available_robot_type in available_robots:
    TEST_ROBOT_TYPES += [(available_robot_type, True), (available_robot_type, False)]

TEST_CAMERA_TYPES = []
for camera_type in available_cameras:
    TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
for available_camera_type in available_cameras:
    TEST_CAMERA_TYPES += [(available_camera_type, True), (available_camera_type, False)]

TEST_MOTOR_TYPES = []
for motor_type in available_motors:
    TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
for available_motor_type in available_motors:
    TEST_MOTOR_TYPES += [(available_motor_type, True), (available_motor_type, False)]

# Camera indices used for connecting physical cameras
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))

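The loop-variable renames above (`robot_type` → `available_robot_type`, and the camera/motor equivalents) keep module scope from sharing names with test parameters, which is what pylint's `redefined-outer-name` check flags. A stripped-down sketch with hypothetical values:

```python
AVAILABLE_ROBOTS = ["so100", "koch"]  # hypothetical stand-in for available_robots

TEST_ROBOT_TYPES = []
for available_robot_type in AVAILABLE_ROBOTS:
    # The distinct loop-variable name means a test argument called `robot_type`
    # no longer shadows a module-level binding.
    TEST_ROBOT_TYPES += [(available_robot_type, True), (available_robot_type, False)]

print(TEST_ROBOT_TYPES)
# [('so100', True), ('so100', False), ('koch', True), ('koch', False)]
```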
@@ -72,7 +72,6 @@ def require_x86_64_kernel(func):
    """
    Decorator that skips the test if the platform device is not an x86_64 cpu.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):

@@ -87,7 +86,6 @@ def require_cpu(func):
    """
    Decorator that skips the test if device is not cpu.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):

@@ -102,7 +100,6 @@ def require_cuda(func):
    """
    Decorator that skips the test if cuda is not available.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):

@@ -288,17 +285,17 @@ def mock_calibration_dir(calibration_dir):
        "motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
    }
    Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
    with open(calibration_dir / "main_follower.json", "w") as f:
    with open(calibration_dir / "main_follower.json", "w", encoding="utf-8") as f:
        json.dump(example_calib, f)
    with open(calibration_dir / "main_leader.json", "w") as f:
    with open(calibration_dir / "main_leader.json", "w", encoding="utf-8") as f:
        json.dump(example_calib, f)
    with open(calibration_dir / "left_follower.json", "w") as f:
    with open(calibration_dir / "left_follower.json", "w", encoding="utf-8") as f:
        json.dump(example_calib, f)
    with open(calibration_dir / "left_leader.json", "w") as f:
    with open(calibration_dir / "left_leader.json", "w", encoding="utf-8") as f:
        json.dump(example_calib, f)
    with open(calibration_dir / "right_follower.json", "w") as f:
    with open(calibration_dir / "right_follower.json", "w", encoding="utf-8") as f:
        json.dump(example_calib, f)
    with open(calibration_dir / "right_leader.json", "w") as f:
    with open(calibration_dir / "right_leader.json", "w", encoding="utf-8") as f:
        json.dump(example_calib, f)