[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:41:27 +00:00
committed by AdilZouitine
parent 2945bbb221
commit 7c05755823
123 changed files with 1161 additions and 3425 deletions

View File

@@ -46,18 +46,14 @@ def sample_indices(data_len: int) -> list[int]:
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
def auto_downsample_height_width(
img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
):
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
_, height, width = img.shape
if max(width, height) < max_size_threshold:
# no downsampling needed
return img
downsample_factor = (
int(width / target_size) if width > height else int(height / target_size)
)
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
return img[:, ::downsample_factor, ::downsample_factor]
@@ -79,9 +75,7 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def get_feature_stats(
array: np.ndarray, axis: tuple, keepdims: bool
) -> dict[str, np.ndarray]:
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
return {
"min": np.min(array, axis=axis, keepdims=keepdims),
"max": np.max(array, axis=axis, keepdims=keepdims),
@@ -91,9 +85,7 @@ def get_feature_stats(
}
def compute_episode_stats(
episode_data: dict[str, list[str] | np.ndarray], features: dict
) -> dict:
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
@@ -107,15 +99,12 @@ def compute_episode_stats(
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
ep_stats[key] = get_feature_stats(
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims
)
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0)
for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
@@ -130,17 +119,11 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
)
if v.ndim == 0:
raise ValueError(
"Number of dimensions must be at least 1, and is 0 instead."
)
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
if k == "count" and v.shape != (1,):
raise ValueError(
f"Shape of 'count' must be (1), but is {v.shape} instead."
)
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
raise ValueError(
f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead."
)
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
def aggregate_feature_stats(

View File

@@ -58,9 +58,7 @@ def resolve_delta_timestamps(
if key == "action" and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [
i / ds_meta.fps for i in cfg.observation_delta_indices
]
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0:
delta_timestamps = None
@@ -81,9 +79,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
LeRobotDataset | MultiLeRobotDataset
"""
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms)
if cfg.dataset.image_transforms.enable
else None
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
if isinstance(cfg.dataset.repo_id, str):
@@ -117,8 +113,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(
stats, dtype=torch.float32
)
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset

View File

@@ -38,14 +38,10 @@ def safe_stop_image_writer(func):
return wrapper
def image_array_to_pil_image(
image_array: np.ndarray, range_check: bool = True
) -> PIL.Image.Image:
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(
f"The array has {image_array.ndim} dimensions, but 3 is expected for an image."
)
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
@@ -131,9 +127,7 @@ class AsyncImageWriter:
self._stopped = False
if num_threads <= 0 and num_processes <= 0:
raise ValueError(
"Number of threads and processes must be greater than zero."
)
raise ValueError("Number of threads and processes must be greater than zero.")
if self.num_processes == 0:
# Use threading
@@ -147,16 +141,12 @@ class AsyncImageWriter:
# Use multiprocessing
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
p = multiprocessing.Process(
target=worker_process, args=(self.queue, self.num_threads)
)
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
p.daemon = True
p.start()
self.processes.append(p)
def save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
):
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()

View File

@@ -108,9 +108,7 @@ class LeRobotDatasetMetadata:
self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(
self.stats, self.episodes
)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
@@ -141,9 +139,7 @@ class LeRobotDatasetMetadata:
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.video_path.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
)
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
@@ -187,11 +183,7 @@ class LeRobotDatasetMetadata:
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [
key
for key, ft in self.features.items()
if ft["dtype"] in ["video", "image"]
]
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def names(self) -> dict[str, list | dict]:
@@ -240,9 +232,7 @@ class LeRobotDatasetMetadata:
Given a task in natural language, add it to the dictionary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(
f"The task '{task}' already exists and can't be added twice."
)
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index
@@ -285,11 +275,7 @@ class LeRobotDatasetMetadata:
write_episode(episode_dict, self.root)
self.episodes_stats[episode_index] = episode_stats
self.stats = (
aggregate_stats([self.stats, episode_stats])
if self.stats
else episode_stats
)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
write_episode_stats(episode_index, episode_stats, self.root)
def update_video_info(self) -> None:
@@ -299,9 +285,7 @@ class LeRobotDatasetMetadata:
"""
for key in self.video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.get_video_file_path(
ep_index=0, vid_key=key
)
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["features"][key]["info"] = get_video_info(video_path)
def __repr__(self):
@@ -353,17 +337,13 @@ class LeRobotDatasetMetadata:
# as this would break the dict flattening in the stats computation, which uses '/' as separator
for key in features:
if "/" in key:
raise ValueError(
f"Feature names should not contain '/'. Found '/' in feature '{key}'."
)
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.task_to_task_index = {}, {}
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, features, use_videos
)
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
@@ -494,9 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = (
video_backend if video_backend else get_safe_default_codec()
)
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
# Unused attributes
@@ -509,39 +487,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and self.meta._version >= packaging.version.parse(
"v2.1"
):
episodes_stats = [
self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
]
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats)
# Load actual data
try:
if force_cache_sync:
raise FileNotFoundError
assert all(
(self.root / fpath).is_file()
for fpath in self.get_episodes_file_paths()
)
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(
timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
)
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
@@ -593,9 +560,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
hub_api.upload_folder(**upload_kwargs)
if not hub_api.file_exists(
self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
):
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
@@ -603,12 +568,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(
self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
)
hub_api.create_tag(
self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
def pull_from_repo(
self,
@@ -640,11 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]:
episodes = (
self.episodes
if self.episodes is not None
else list(range(self.meta.total_episodes))
)
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [
@@ -662,10 +619,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
path = str(self.root / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
else:
files = [
str(self.root / self.meta.get_data_file_path(ep_idx))
for ep_idx in self.episodes
]
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
hf_dataset = load_dataset("parquet", data_files=files, split="train")
# TODO(aliberts): hf_dataset.set_format("torch")
@@ -675,9 +629,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(
ft_dict, features=features, split="train"
)
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
@@ -691,20 +643,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
return (
len(self.hf_dataset)
if self.hf_dataset is not None
else self.meta.total_frames
)
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
return (
len(self.episodes)
if self.episodes is not None
else self.meta.total_episodes
)
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
@property
def features(self) -> dict[str, dict]:
@@ -718,24 +662,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
return get_hf_features_from_features(self.features)
def _get_query_indices(
self, idx: int, ep_idx: int
) -> tuple[dict[str, list[int | bool]]]:
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
query_indices = {
key: [
max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
for delta in delta_idx
]
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[
(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
for delta in delta_idx
]
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
@@ -763,9 +699,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
def _query_videos(
self, query_timestamps: dict[str, list[float]], ep_idx: int
) -> dict[str, torch.Tensor]:
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -774,9 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(
video_path, query_ts, self.tolerance_s, self.video_backend
)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
item[vid_key] = frames.squeeze(0)
return item
@@ -830,9 +762,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = (
self.meta.total_episodes if episode_index is None else episode_index
)
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
ep_buffer = {}
# size and task are special cases that are not in self.features
ep_buffer["size"] = 0
@@ -841,17 +771,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
def _get_image_file_path(
self, episode_index: int, image_key: str, frame_index: int
) -> Path:
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self.root / fpath
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
) -> None:
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
@@ -877,9 +803,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
timestamp = (
frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
)
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
@@ -930,9 +854,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionary
@@ -942,9 +864,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta.add_task(task)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array(
[self.meta.get_task_index(task) for task in tasks]
)
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items():
# index, episode_index, task_index are already processed above, and image and video
@@ -994,9 +914,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(
episode_dict, features=self.hf_features, split="train"
)
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
self.hf_dataset.set_transform(hf_transform_to_torch)
@@ -1115,9 +1033,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = (
video_backend if video_backend is not None else get_safe_default_codec()
)
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj
@@ -1142,9 +1058,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = (
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
)
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
@@ -1223,13 +1137,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update(
{
k: v
for k, v in dataset.hf_features.items()
if k not in self.disabled_features
}
)
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
return features
@property
@@ -1290,9 +1198,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
continue
break
else:
raise AssertionError(
"We expect the loop to break out as long as the index is within bounds."
)
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:

View File

@@ -131,9 +131,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
else:
self._delta_timestamps = None
def _make_data_spec(
self, data_spec: dict[str, Any], buffer_capacity: int
) -> dict[str, dict[str, Any]]:
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
"""Makes the data spec for np.memmap."""
if any(k.startswith("_") for k in data_spec):
raise ValueError(
@@ -208,9 +206,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Shift the incoming indices if necessary.
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
next_index - 1
]
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
@@ -245,11 +241,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(
np.unique(
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
]
)
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
)
@property
@@ -287,9 +279,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
)
)[0]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
episode_data_indices
]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
for data_key in self.delta_timestamps:
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
@@ -306,8 +296,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Check violated query timestamps are all outside the episode range.
assert (
(query_ts[is_pad] < episode_timestamps[0])
| (episode_timestamps[-1] < query_ts[is_pad])
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
") inside the episode range."
@@ -322,9 +311,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
def get_data_by_key(self, key: str) -> torch.Tensor:
"""Returns all data for a given data key as a Tensor."""
return torch.from_numpy(
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
)
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
def compute_sampler_weights(
@@ -355,19 +342,13 @@ def compute_sampler_weights(
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
included here to avoid adding complexity.
"""
if len(offline_dataset) == 0 and (
online_dataset is None or len(online_dataset) == 0
):
raise ValueError(
"At least one of `offline_dataset` or `online_dataset` should be contain data."
)
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
if (online_dataset is None) ^ (online_sampling_ratio is None):
raise ValueError(
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
)
offline_sampling_ratio = (
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
)
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
weights = []

View File

@@ -45,9 +45,7 @@ def concatenate_episodes(ep_dicts):
return data_dict
def save_images_concurrently(
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
):
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
@@ -57,10 +55,7 @@ def save_images_concurrently(
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[
executor.submit(save_image, imgs_array[i], i, out_dir)
for i in range(num_images)
]
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
def get_default_encoding() -> dict:
@@ -69,8 +64,7 @@ def get_default_encoding() -> dict:
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty
and k in ["vcodec", "pix_fmt", "g", "crf"]
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
}

View File

@@ -58,9 +58,7 @@ class RandomSubsetApply(Transform):
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(
f"n_subset should be in the interval [1, {len(transforms)}]"
)
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
@@ -121,36 +119,26 @@ class SharpnessJitter(Transform):
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if sharpness < 0:
raise ValueError(
"If sharpness is a single number, it must be non negative."
)
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(
f"{sharpness=} should be a single number or a sequence with length 2."
)
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(
f"sharpnesss values should be between (0., inf), but got {sharpness}."
)
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
return float(sharpness[0]), float(sharpness[1])
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
sharpness_factor = (
torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
)
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
return {"sharpness_factor": sharpness_factor}
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = params["sharpness_factor"]
return self._call_kernel(
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
)
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
@dataclass

View File

@@ -52,15 +52,9 @@ STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = (
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
)
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
)
DEFAULT_IMAGE_PATH = (
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
)
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DATASET_CARD_TEMPLATE = """
---
@@ -135,9 +129,7 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
elif isinstance(value, (int, float)):
serialized_dict[key] = value
else:
raise NotImplementedError(
f"The value '{value}' of type '{type(value)}' is not supported."
)
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
return unflatten_dict(serialized_dict)
@@ -216,10 +208,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {
item["task_index"]: item["task"]
for item in sorted(tasks, key=lambda x: x["task_index"])
}
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
@@ -230,10 +219,7 @@ def write_episode(episode: dict, local_dir: Path):
def load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {
item["episode_index"]: item
for item in sorted(episodes, key=lambda x: x["episode_index"])
}
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
@@ -286,9 +272,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None:
pass
else:
items_dict[key] = [
x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]
]
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
@@ -341,9 +325,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
Otherwise, will throw a `CompatibilityError`.
"""
target_version = (
packaging.version.parse(version)
if not isinstance(version, packaging.version.Version)
else version
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
)
hub_versions = get_repo_versions(repo_id)
@@ -364,16 +346,12 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
return f"v{target_version}"
compatibles = [
v
for v in hub_versions
if v.major == target_version.major and v.minor <= target_version.minor
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
]
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(
f"Revision {version} for {repo_id} not found, using version v{return_version}"
)
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
@@ -480,9 +458,7 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {
ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()
}
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -532,9 +508,7 @@ def check_timestamps_sync(
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = (
episode_data_index["to"][:-1] - 1
) # indices at the end of each episode
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
@@ -580,14 +554,10 @@ def check_delta_timestamps(
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
]
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
if not all(within_tolerance):
outside_tolerance[key] = [
ts
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
if not is_within
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
]
if len(outside_tolerance) > 0:
@@ -605,9 +575,7 @@ def check_delta_timestamps(
return True
def get_delta_indices(
delta_timestamps: dict[str, list[float]], fps: int
) -> dict[str, list[int]]:
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
@@ -672,9 +640,7 @@ def create_lerobot_dataset_card(
],
)
card_template = (
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
).read_text()
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
return DatasetCard.from_template(
card_data=card_data,
@@ -743,18 +709,14 @@ def validate_frame(frame: dict, features: dict):
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
error_message = validate_features_presence(
actual_features, expected_features, optional_features
)
error_message = validate_features_presence(actual_features, expected_features, optional_features)
if "task" in frame:
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(
name, features[name], frame[name]
)
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
@@ -777,9 +739,7 @@ def validate_features_presence(
return error_message
def validate_feature_dtype_and_shape(
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
):
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
@@ -789,9 +749,7 @@ def validate_feature_dtype_and_shape(
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(
f"The feature dtype '{expected_dtype}' is not implemented yet."
)
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def validate_feature_numpy_array(
@@ -813,17 +771,13 @@ def validate_feature_numpy_array(
return error_message
def validate_feature_image_or_video(
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
):
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (
actual_shape != (c, h, w) and actual_shape != (h, w, c)
):
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
@@ -854,9 +808,7 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
)
if episode_buffer["size"] == 0:
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):

View File

@@ -218,9 +218,7 @@ def get_features_from_hf_dataset(
dtype = ft.feature.dtype
shape = (ft.length,)
motor_names = (
robot_config["names"][key]
if robot_config
else [f"motor_{i}" for i in range(ft.length)]
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
@@ -244,15 +242,11 @@ def get_features_from_hf_dataset(
return features
def add_task_index_by_episodes(
dataset: Dataset, tasks_by_episodes: dict
) -> tuple[Dataset, list[str]]:
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features
@@ -269,19 +263,10 @@ def add_task_index_from_tasks_col(
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = (
df[tasks_col]
.str.removeprefix(prefix_to_clean)
.str.removesuffix(suffix_to_clean)
)
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
# Create task_index col
tasks_by_episode = (
df.groupby("episode_index")[tasks_col]
.unique()
.apply(lambda x: x.tolist())
.to_dict()
)
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
@@ -306,9 +291,7 @@ def split_parquet_by_episodes(
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk
)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
@@ -340,9 +323,7 @@ def move_videos(
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0:
video_files = [
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
]
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
@@ -373,9 +354,7 @@ def move_videos(
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(
video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
@@ -392,9 +371,7 @@ def move_videos(
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(
work_dir: Path, lfs_untracked_videos: list[str]
) -> None:
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
@@ -418,14 +395,10 @@ def fix_lfs_video_files_tracking(
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
) -> None:
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
@@ -462,9 +435,7 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
return [f for f in video_files if f not in lfs_tracked_files]
def get_videos_info(
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
) -> dict:
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
@@ -539,31 +510,19 @@ def convert_dataset(
if single_task:
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
}
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
dataset, tasks_col
)
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
else:
raise ValueError
assert set(tasks) == {
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
}
tasks = [
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
]
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
"dtype": "int64",
@@ -593,9 +552,7 @@ def convert_dataset(
clean_gitattr,
branch,
)
videos_info = get_videos_info(
repo_id, v1x_dir, video_keys=video_keys, branch=branch
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.height"),
@@ -603,22 +560,15 @@ def convert_dataset(
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
)
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
if "encoding" in metadata_v1:
assert (
videos_info[key]["video.pix_fmt"]
== metadata_v1["encoding"]["pix_fmt"]
)
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(
dataset, total_episodes, total_chunks, v20_dir
)
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
if robot_config is not None:
robot_type = robot_config.type
@@ -656,14 +606,10 @@ def convert_dataset(
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
)
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(
@@ -674,9 +620,7 @@ def convert_dataset(
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
)
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.upload_folder(
repo_id=repo_id,

View File

@@ -35,30 +35,22 @@ def fix_dataset(repo_id: str) -> str:
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(
repo_id, revision=V20, force_cache_sync=True
)
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
meta_features = {
key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"
}
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(
f"In parquet not in info.json: {parquet_features - meta_features}"
)
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(
f"In info.json not in parquet: {meta_features - parquet_features}"
)
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)

View File

@@ -99,9 +99,7 @@ def convert_dataset(
repo_type="dataset",
)
hub_api.create_tag(
repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if __name__ == "__main__":

View File

@@ -26,9 +26,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
def sample_episode_video_frames(
dataset: LeRobotDataset, episode_index: int, ft_key: str
) -> np.ndarray:
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
@@ -51,14 +49,11 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(
ep_ft_data, axis=axes_to_reduce, keepdims=keepdims
)
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0)
for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats

View File

@@ -65,9 +65,7 @@ def decode_video_frames(
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
@@ -346,9 +344,7 @@ def get_audio_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@@ -362,9 +358,7 @@ def get_audio_info(video_path: Path | str) -> dict:
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"])
if audio_stream_info.get("bit_rate")
else None,
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
@@ -386,9 +380,7 @@ def get_video_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")

View File

@@ -61,16 +61,10 @@ class AlohaEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(480, 640, 3)
)
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(14,)
)
self.features["pixels/top"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(480, 640, 3)
)
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
@property
def gym_kwargs(self) -> dict:
@@ -108,13 +102,9 @@ class PushtEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["pixels"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(384, 384, 3)
)
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(
type=FeatureType.ENV, shape=(16,)
)
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
@property
def gym_kwargs(self) -> dict:
@@ -153,9 +143,7 @@ class XarmEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(4,)
)
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
@property
def gym_kwargs(self) -> dict:

View File

@@ -32,9 +32,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env(
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> gym.vector.VectorEnv | None:
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the config.
Args:
@@ -58,9 +56,7 @@ def make_env(
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
)
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
gym_handle = f"{package_name}/{cfg.task}"
@@ -68,18 +64,13 @@ def make_env(
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
for _ in range(n_envs)
]
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
return env
def make_maniskill_env(
cfg: DictConfig, n_envs: int | None = None
) -> gym.vector.VectorEnv | None:
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
@@ -96,9 +87,7 @@ def make_maniskill_env(
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
return env
@@ -125,11 +114,7 @@ class PixelWrapper(gym.Wrapper):
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
self.env.device
)
}
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
@@ -164,9 +149,7 @@ class ConvertToLeRobotEnv(gym.Wrapper):
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(
observation, use_torch=True, device=self.base_env.device
)
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
ret = dict()
ret["state"] = observation
ret["pixels"] = images

View File

@@ -46,9 +46,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, (
f"expect channel last images, but instead got {img.shape=}"
)
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -81,9 +79,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:
if len(ft.shape) != 3:
raise ValueError(
f"Number of dimensions of {key} != 3 (shape={ft.shape})"
)
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
shape = get_channel_first_image_shape(ft.shape)
feature = PolicyFeature(type=ft.type, shape=shape)

View File

@@ -34,13 +34,7 @@ def make_optimizer_and_scheduler(
Returns:
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
"""
params = (
policy.get_optim_params()
if cfg.use_policy_training_preset
else policy.parameters()
)
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
optimizer = cfg.optimizer.build(params)
lr_scheduler = (
cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
return optimizer, lr_scheduler

View File

@@ -102,9 +102,7 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(
optimizer: torch.optim.Optimizer, save_dir: Path
) -> torch.optim.Optimizer:
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)

View File

@@ -36,9 +36,7 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
return self.get_choice_name(self.__class__)
@abc.abstractmethod
def build(
self, optimizer: Optimizer, num_training_steps: int
) -> LRScheduler | None:
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
raise NotImplementedError
@@ -79,11 +77,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
)
return max(
0.0,
0.5
* (
1.0
+ math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
),
0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)),
)
return LambdaLR(optimizer, lr_lambda, -1)
@@ -111,9 +105,7 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
def cosine_decay_schedule(current_step):
step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (
1 + math.cos(math.pi * step / self.num_decay_steps)
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed
@@ -132,8 +124,6 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(
save_dir / SCHEDULER_STATE, scheduler.state_dict()
)
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
scheduler.load_state_dict(state_dict)
return scheduler

View File

@@ -171,9 +171,7 @@ class ACTConfig(PreTrainedConfig):
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:

View File

@@ -63,9 +63,7 @@ class ACTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -76,9 +74,7 @@ class ACTPolicy(PreTrainedPolicy):
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(
config.temporal_ensemble_coeff, config.chunk_size
)
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
@@ -122,12 +118,8 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [
batch[key] for key in self.config.image_features
]
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -154,19 +146,14 @@ class ACTPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [
batch[key] for key in self.config.image_features
]
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss.item()}
@@ -176,12 +163,7 @@ class ACTPolicy(PreTrainedPolicy):
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(
-0.5
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
)
.sum(-1)
.mean()
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss = l1_loss + mean_kld * self.config.kl_weight
@@ -235,9 +217,7 @@ class ACTTemporalEnsembler:
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(
-temporal_ensemble_coeff * torch.arange(chunk_size)
)
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
@@ -253,9 +233,7 @@ class ACTTemporalEnsembler:
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
device=actions.device
)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
@@ -270,22 +248,12 @@ class ACTTemporalEnsembler:
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[
self.ensembled_actions_count - 1
]
self.ensembled_actions += (
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
)
self.ensembled_actions /= self.ensemble_weights_cumsum[
self.ensembled_actions_count
]
self.ensembled_actions_count = torch.clamp(
self.ensembled_actions_count + 1, max=self.chunk_size
)
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat(
[self.ensembled_actions, actions[:, -1:]], dim=1
)
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[
self.ensembled_actions_count,
@@ -356,9 +324,7 @@ class ACT(nn.Module):
config.dim_model,
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(
config.dim_model, config.latent_dim * 2
)
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
@@ -366,9 +332,7 @@ class ACT(nn.Module):
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(
num_input_token_encoder, config.dim_model
).unsqueeze(0),
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
)
# Backbone for image feature extraction.
@@ -385,9 +349,7 @@ class ACT(nn.Module):
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(
backbone_model, return_layers={"layer4": "feature_map"}
)
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
@@ -416,18 +378,14 @@ class ACT(nn.Module):
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
config.dim_model // 2
)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(
config.dim_model, self.config.action_feature.shape[0]
)
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
self._reset_parameters()
@@ -437,9 +395,7 @@ class ACT(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
@@ -475,13 +431,9 @@ class ACT(nn.Module):
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(
batch["observation.state"]
)
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [
@@ -526,26 +478,20 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros(
[batch_size, self.config.latent_dim], dtype=torch.float32
).to(batch["observation.state"].device)
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
)
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(
self.encoder_robot_state_input_proj(batch["observation.state"])
)
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(
batch["observation.environment_state"]
)
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
@@ -556,9 +502,7 @@ class ACT(nn.Module):
# For a list of images, the H and W may vary but H*W is constant.
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
dtype=cam_features.dtype
)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features)
# Rearrange features to (sequence, batch, dim).
@@ -604,14 +548,8 @@ class ACTEncoder(nn.Module):
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
super().__init__()
self.is_vae_encoder = is_vae_encoder
num_layers = (
config.n_vae_encoder_layers
if self.is_vae_encoder
else config.n_encoder_layers
)
self.layers = nn.ModuleList(
[ACTEncoderLayer(config) for _ in range(num_layers)]
)
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
@@ -629,9 +567,7 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -646,9 +582,7 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
@@ -673,9 +607,7 @@ class ACTDecoder(nn.Module):
def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList(
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
)
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.norm = nn.LayerNorm(config.dim_model)
def forward(
@@ -700,12 +632,8 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.multihead_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -746,9 +674,7 @@ class ACTDecoderLayer(nn.Module):
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[
0
] # select just the output, not the attention weights
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
@@ -785,14 +711,9 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
"""
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / dimension)
for hid_j in range(dimension)
]
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
)
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
@@ -837,9 +758,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
/ self.dimension
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
@@ -847,15 +766,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack(
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed_y = torch.stack(
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
0, 3, 1, 2
) # (1, C, H, W)
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
return pos_embed

View File

@@ -205,16 +205,11 @@ class DiffusionConfig(PreTrainedConfig):
def validate_features(self) -> None:
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if (
self.crop_shape[0] > image_ft.shape[1]
or self.crop_shape[1] > image_ft.shape[2]
):
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "

View File

@@ -70,9 +70,7 @@ class DiffusionPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -99,9 +97,7 @@ class DiffusionPolicy(PreTrainedPolicy):
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(
maxlen=self.config.n_obs_steps
)
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -127,9 +123,7 @@ class DiffusionPolicy(PreTrainedPolicy):
"""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
@@ -138,11 +132,7 @@ class DiffusionPolicy(PreTrainedPolicy):
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
@@ -157,9 +147,7 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
@@ -201,9 +189,7 @@ class DiffusionModel(nn.Module):
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(
config, global_cond_dim=global_cond_dim * config.n_obs_steps
)
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
@@ -249,9 +235,7 @@ class DiffusionModel(nn.Module):
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(
model_output, t, sample, generator=generator
).prev_sample
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample
@@ -263,15 +247,11 @@ class DiffusionModel(nn.Module):
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(
batch["observation.images"], "b s n ... -> n (b s) ..."
)
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(
self.rgb_encoder, images_per_camera, strict=True
)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
@@ -285,9 +265,7 @@ class DiffusionModel(nn.Module):
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(
batch["observation.images"], "b s n ... -> (b s n) ..."
)
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
@@ -381,9 +359,7 @@ class DiffusionModel(nn.Module):
elif self.config.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(
f"Unsupported prediction type {self.config.prediction_type}"
)
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
loss = F.mse_loss(pred, target, reduction="none")
@@ -443,9 +419,7 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@@ -487,9 +461,7 @@ class DiffusionRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -510,9 +482,7 @@ class DiffusionRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
@@ -523,15 +493,11 @@ class DiffusionRgbEncoder(nn.Module):
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = (
config.crop_shape if config.crop_shape is not None else images_shape[1:]
)
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@@ -573,11 +539,7 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@@ -592,9 +554,7 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
@@ -622,9 +582,7 @@ class DiffusionConv1dBlock(nn.Module):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
),
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
@@ -647,13 +605,9 @@ class DiffusionConditionalUnet1d(nn.Module):
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
@@ -678,16 +632,10 @@ class DiffusionConditionalUnet1d(nn.Module):
self.down_modules.append(
nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
dim_in, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
if not is_last
else nn.Identity(),
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
@@ -716,24 +664,16 @@ class DiffusionConditionalUnet1d(nn.Module):
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(
dim_in * 2, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
if not is_last
else nn.Identity(),
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
),
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
)
@@ -801,23 +741,17 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = DiffusionConv1dBlock(
in_channels, out_channels, kernel_size, n_groups=n_groups
)
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = DiffusionConv1dBlock(
out_channels, out_channels, kernel_size, n_groups=n_groups
)
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1)
if in_channels != out_channels
else nn.Identity()
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:

View File

@@ -104,9 +104,7 @@ def make_policy(
PreTrainedPolicy: _description_
"""
if bool(ds_meta) == bool(env_cfg):
raise ValueError(
"Either one of a dataset metadata or a sim env must be provided."
)
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
@@ -136,12 +134,8 @@ def make_policy(
)
features = env_to_policy_features(env_cfg)
cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
cfg.input_features = {
key: ft for key, ft in features.items() if key not in cfg.output_features
}
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
if cfg.pretrained_path:

View File

@@ -7,9 +7,7 @@ from torch import Tensor, nn
from .configuration_classifier import ClassifierConfig
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -53,9 +51,7 @@ class Classifier(
super().__init__()
self.config = config
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(
self.config.model_name, trust_remote_code=True
)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
@@ -81,9 +77,7 @@ class Classifier(
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[
-1
] # Last channel dimension
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
@@ -103,9 +97,7 @@ class Classifier(
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError(
"Unsupported transformer architecture since hidden_size is not found"
)
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
@@ -141,10 +133,7 @@ class Classifier(
return features
else: # Transformer models
outputs = self.encoder(processed)
if (
hasattr(outputs, "pooler_output")
and outputs.pooler_output is not None
):
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
@@ -160,9 +149,7 @@ class Classifier(
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:

View File

@@ -82,43 +82,25 @@ def create_stats_buffers(
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(
dtype=torch.float32
)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(
dtype=torch.float32
)
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(
dtype=torch.float32
)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(
dtype=torch.float32
)
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = (
stats[key]["mean"].clone().to(dtype=torch.float32)
)
buffer["std"].data = (
stats[key]["std"].clone().to(dtype=torch.float32)
)
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = (
stats[key]["min"].clone().to(dtype=torch.float32)
)
buffer["max"].data = (
stats[key]["max"].clone().to(dtype=torch.float32)
)
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(
f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead."
)
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers

View File

@@ -44,9 +44,7 @@ def main():
else:
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
ckpt_torch_dir = (
Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
)
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
save_dir = Path(f"../openpi/data/{model_name}/save")
@@ -72,9 +70,7 @@ def main():
# Create LeRobot batch from Jax
batch = {}
for cam_key, uint_chw_array in example["images"].items():
batch[f"observation.images.{cam_key}"] = (
torch.from_numpy(uint_chw_array) / 255.0
)
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch["observation.state"] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"]

View File

@@ -54,9 +54,7 @@ def get_paligemma_config(precision: str):
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(
text_config=text_config, vision_config=vision_config, **config
)
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
return final_config

View File

@@ -322,9 +322,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
return {f"{prefix}{key}": value for key, value in d.items()}
def convert_pi0_checkpoint(
checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str
):
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
# Break down orbax ckpts - they are in OCDBT
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
# process projection params
@@ -384,9 +382,7 @@ def convert_pi0_checkpoint(
# gemma_config=gemma_config, paligemma_config=paligemma_config)
pi0_model = PI0Policy(pi0_config)
paligemma_params = update_keys_with_prefix(
paligemma_params, "model.paligemma_with_expert."
)
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
projection_params = update_keys_with_prefix(projection_params, "model.")

View File

@@ -193,9 +193,7 @@ def aloha_gripper_to_angular(value):
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
2 * horn_radius * linear_position
)
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
@@ -246,9 +244,7 @@ class PI0Policy(PreTrainedPolicy):
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -256,9 +252,7 @@ class PI0Policy(PreTrainedPolicy):
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224"
)
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FlowMatching(config)
self.reset()
@@ -271,9 +265,7 @@ class PI0Policy(PreTrainedPolicy):
return self.parameters()
@torch.no_grad
def select_action(
self, batch: dict[str, Tensor], noise: Tensor | None = None
) -> Tensor:
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -312,9 +304,7 @@ class PI0Policy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(
self, batch: dict[str, Tensor], noise=None, time=None
) -> tuple[Tensor, dict[str, Tensor]]:
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -330,9 +320,7 @@ class PI0Policy(PreTrainedPolicy):
actions_is_pad = batch.get("action_is_pad")
loss_dict = {}
losses = self.model.forward(
images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
)
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
@@ -359,9 +347,7 @@ class PI0Policy(PreTrainedPolicy):
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [
key for key in self.config.image_features if key not in batch
]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
@@ -373,9 +359,7 @@ class PI0Policy(PreTrainedPolicy):
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(
img, *self.config.resize_imgs_with_padding, pad_value=0
)
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
img = img * 2.0 - 1.0
@@ -414,9 +398,7 @@ class PI0Policy(PreTrainedPolicy):
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(
device=device, dtype=torch.bool
)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
@@ -435,9 +417,7 @@ class PI0Policy(PreTrainedPolicy):
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(
actions[:, :, motor_idx]
)
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
@@ -446,9 +426,7 @@ class PI0Policy(PreTrainedPolicy):
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
actions[:, :, motor_idx]
)
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
def prepare_state(self, batch):
@@ -498,25 +476,15 @@ class PI0FlowMatching(nn.Module):
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_with_export_config
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
# Projections are float32
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
self.action_in_proj = nn.Linear(
self.config.max_action_dim, self.config.proj_width
)
self.action_out_proj = nn.Linear(
self.config.proj_width, self.config.max_action_dim
)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(
self.config.proj_width * 2, self.config.proj_width
)
self.action_time_mlp_out = nn.Linear(
self.config.proj_width, self.config.proj_width
)
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
self.set_requires_grad()
@@ -560,9 +528,7 @@ class PI0FlowMatching(nn.Module):
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(
img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
)
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
@@ -637,9 +603,7 @@ class PI0FlowMatching(nn.Module):
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(
bsize, action_time_dim, dtype=torch.bool, device=device
)
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
@@ -677,9 +641,7 @@ class PI0FlowMatching(nn.Module):
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
state, x_t, time
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
@@ -703,9 +665,7 @@ class PI0FlowMatching(nn.Module):
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None
) -> Tensor:
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
@@ -763,16 +723,12 @@ class PI0FlowMatching(nn.Module):
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
state, x_t, timestep
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
batch_size, suffix_len, prefix_len
)
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

View File

@@ -39,13 +39,9 @@ def apply_rope(x, positions, max_wavelength=10_000):
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
d_half, dtype=torch.float32, device=device
)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
torch.float32
)
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
@@ -178,9 +174,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(
config=config.paligemma_config
)
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
@@ -297,9 +291,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat(
[past_key_values[layer_idx]["key_states"], key_states], dim=1
)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states],
dim=1,
@@ -392,9 +384,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
value_states,
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = (
self.config.paligemma_config.text_config.num_key_value_heads
)
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
# query_states: batch_size, sequence_length, num_att_head, head_dim
@@ -442,9 +432,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
att_weights *= head_dim**-0.5
big_neg = -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(
attention_mask[:, None, :, :], att_weights, big_neg
)
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
@@ -456,8 +444,6 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(
batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
)
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output

View File

@@ -71,9 +71,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
def _save_pretrained(self, save_directory: Path) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(
model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)
)
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
@classmethod
def from_pretrained(
@@ -112,9 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
policy = cls._load_as_safetensor(
instance, model_file, config.device, strict
)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
else:
try:
model_file = hf_hub_download(
@@ -128,9 +124,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
token=token,
local_files_only=local_files_only,
)
policy = cls._load_as_safetensor(
instance, model_file, config.device, strict
)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
@@ -141,12 +135,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
return policy
@classmethod
def _load_as_safetensor(
cls, model: T, model_file: str, map_location: str, strict: bool
) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse(
"0.4.3"
):
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
load_model_as_safetensor(model, model_file, strict=strict)
if map_location != "cpu":
logging.warning(
@@ -157,9 +147,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
)
model.to(map_location)
else:
safetensors.torch.load_model(
model, model_file, strict=strict, device=map_location
)
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return model
# def generate_model_card(self, *args, **kwargs) -> ModelCard:

View File

@@ -48,9 +48,7 @@ class SACConfig:
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": {"min": [-1, -1], "max": [1, 1]},

View File

@@ -18,8 +18,8 @@
# TODO: (1) better device management
import math
from typing import Callable, Optional, Tuple, Union, Dict, List
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import einops
import numpy as np
@@ -124,17 +124,13 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
),
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = (
-np.prod(config.output_shapes["action"][0]) / 2
) # (-dim(A)/2)
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -146,10 +142,11 @@ class SACPolicy(
def _save_pretrained(self, save_directory):
"""Custom save method to handle TensorDict properly"""
import os
import json
import os
from dataclasses import asdict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import save_model
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
@@ -177,12 +174,14 @@ class SACPolicy(
**model_kwargs,
) -> "SACPolicy":
"""Custom load method to handle loading SAC policy from saved files"""
import os
import json
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
@@ -302,14 +301,10 @@ class SACPolicy(
) -> Tensor:
self.temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
)
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
"action"
]
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]
# 2- compute q targets
q_targets = self.critic_forward(
@@ -353,21 +348,15 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_temperature(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
).mean()
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
self.temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
@@ -408,11 +397,7 @@ class MLP(nn.Module):
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
# Rest of the layers
for i in range(1, len(hidden_dims)):
@@ -424,11 +409,7 @@ class MLP(nn.Module):
layers.append(nn.LayerNorm(hidden_dims[i]))
# If we're at the final layer and a final activation is specified, use it
if (
i + 1 == len(hidden_dims)
and activate_final
and final_activation is not None
):
if i + 1 == len(hidden_dims) and activate_final and final_activation is not None:
layers.append(
final_activation
if isinstance(final_activation, nn.Module)
@@ -436,9 +417,7 @@ class MLP(nn.Module):
)
else:
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
@@ -639,15 +618,11 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), (
"[ERROR] log_std became NaN after std_layer!"
)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1.0)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
@@ -660,9 +635,7 @@ class Policy(nn.Module):
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log(
(1 - actions.pow(2)) + 1e-6
) # Adjust log-probs for Tanh
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
@@ -709,9 +682,7 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
@@ -738,9 +709,7 @@ class SACObservationEncoder(nn.Module):
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(
in_features=self.aggregation_size, out_features=config.latent_dim
)
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
@@ -753,19 +722,13 @@ class SACObservationEncoder(nn.Module):
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat(
[obs_dict[key] for key in self.all_image_keys], dim=0
)
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(
images_batched, dim=0, chunks=len(self.all_image_keys)
)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
@@ -833,9 +796,7 @@ class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = (
self._load_pretrained_vision_encoder(config)
)
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -846,21 +807,15 @@ class PretrainedImageEncoder(nn.Module):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(
config.vision_encoder_name, trust_remote_code=True
)
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
-1
] # Last channel dimension
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError(
"Unsupported vision encoder architecture, make sure you are using a CNN"
)
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
@@ -896,9 +851,7 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][
key
].view(3, 1, 1)
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
return converted_params

View File

@@ -183,13 +183,9 @@ class TDMPCConfig(PreTrainedConfig):
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
)
if not self.use_mpc:
raise ValueError(
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
)
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
if self.n_action_steps > self.horizon:
raise ValueError(
"`n_action_steps` must be less than or equal to `horizon`."
)
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(lr=self.optimizer_lr)
@@ -209,9 +205,7 @@ class TDMPCConfig(PreTrainedConfig):
if image_ft.shape[-2] != image_ft.shape[-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
f"Only square images are handled now. Got image shape {image_ft.shape}."
)
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
@property
def observation_delta_indices(self) -> list:

View File

@@ -83,9 +83,7 @@ class TDMPCPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -110,9 +108,7 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1)
@@ -127,9 +123,7 @@ class TDMPCPolicy(PreTrainedPolicy):
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[next(iter(self.config.image_features))]
self._queues = populate_queues(self._queues, batch)
@@ -232,47 +226,35 @@ class TDMPCPolicy(PreTrainedPolicy):
self.config.action_feature.shape[0],
device=std.device,
)
gaussian_actions = torch.clamp(
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(
value, self.config.n_elites, dim=0
).indices # (n_elites, batch)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(
self.config.elite_weighting_temperature * (elite_value - max_value)
)
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
** 2,
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean
+ (1 - self.config.gaussian_mean_momentum) * _mean
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
@@ -281,9 +263,7 @@ class TDMPCPolicy(PreTrainedPolicy):
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
]
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
return actions
@@ -306,8 +286,7 @@ class TDMPCPolicy(PreTrainedPolicy):
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff
* self.model.Qs(z, actions[t]).std(0)
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
@@ -328,9 +307,7 @@ class TDMPCPolicy(PreTrainedPolicy):
G += (
running_discount
* torch.min(
terminal_values[
torch.randint(0, self.config.q_ensemble_size, size=(2,))
],
terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))],
dim=0,
)[0]
)
@@ -338,11 +315,7 @@ class TDMPCPolicy(PreTrainedPolicy):
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= (
running_discount
* self.config.uncertainty_regularizer_coeff
* terminal_values.std(0)
)
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
@@ -354,9 +327,7 @@ class TDMPCPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
@@ -388,29 +359,21 @@ class TDMPCPolicy(PreTrainedPolicy):
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
"observation.image"
if self.config.image_features
else "observation.environment_state"
"observation.image" if self.config.image_features else "observation.environment_state"
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(
horizon + 1, batch_size, self.config.latent_dim, device=device
)
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
z_preds[t], action[t]
)
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(
z_preds[:-1], action
) # (ensemble, horizon, batch)
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
@@ -424,14 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy):
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(
self.model.encode(next_observations)
)
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(
z_preds[:-1].detach(), action, return_min=True
)
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@@ -474,9 +433,7 @@ class TDMPCPolicy(PreTrainedPolicy):
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
),
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
@@ -514,14 +471,12 @@ class TDMPCPolicy(PreTrainedPolicy):
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(
z_preds[:-1], action, return_min=True
) - self.model.V(z_preds[:-1])
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(
torch.exp(advantage * self.config.advantage_scaling), max=100.0
)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
@@ -575,9 +530,7 @@ class TDMPCPolicy(PreTrainedPolicy):
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(
self.model_target, self.model, self.config.target_model_momentum
)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
class TDMPCTOLD(nn.Module):
@@ -588,9 +541,7 @@ class TDMPCTOLD(nn.Module):
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -601,9 +552,7 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -671,9 +620,7 @@ class TDMPCTOLD(nn.Module):
"Sanity check. The last linear layer needs 0 initialization on weights."
)
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(
m[-1].bias
) # this has already been done, but keep this line here for good measure
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
@@ -812,9 +759,7 @@ class TDMPCObservationEncoder(nn.Module):
if config.robot_state_feature:
self.state_enc_layers = nn.Sequential(
nn.Linear(
config.robot_state_feature.shape[0], config.state_encoder_hidden_dim
),
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -823,9 +768,7 @@ class TDMPCObservationEncoder(nn.Module):
if config.env_state_feature:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.env_state_feature.shape[0], config.state_encoder_hidden_dim
),
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -898,10 +841,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if (
isinstance(module, nn.modules.batchnorm._BatchNorm)
or not p.requires_grad
):
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
@@ -909,9 +849,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:

View File

@@ -172,10 +172,7 @@ class VQBeTConfig(PreTrainedConfig):
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if (
self.crop_shape[0] > image_ft.shape[1]
or self.crop_shape[1] > image_ft.shape[2]
):
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "

View File

@@ -64,9 +64,7 @@ class VQBeTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -97,17 +95,11 @@ class VQBeTPolicy(PreTrainedPolicy):
if self.config.sequentially_select:
decay_params = (
decay_params
+ list(
self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
)
+ list(
self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
)
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
)
else:
decay_params = decay_params + list(
self.vqbet.action_head.map_to_cbet_preds_bin.parameters()
)
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
return [
{
@@ -145,12 +137,8 @@ class VQBeTPolicy(PreTrainedPolicy):
"""
batch = self.normalize_inputs(batch)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -161,14 +149,8 @@ class VQBeTPolicy(PreTrainedPolicy):
)
if len(self._queues["action"]) == 0:
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.vqbet(batch, rollout=True)[
:, : self.config.action_chunk_size
]
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -181,12 +163,8 @@ class VQBeTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -194,9 +172,7 @@ class VQBeTPolicy(PreTrainedPolicy):
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(
self.config.n_vqvae_training_steps, batch["action"]
)
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
)
return loss, {
"n_different_codes": n_different_codes,
@@ -253,9 +229,7 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@@ -370,12 +344,7 @@ class VQBeTModel(nn.Module):
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack(
[
torch.arange(i, i + self.config.action_chunk_size)
for i in range(num_tokens)
]
),
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
@@ -406,19 +375,13 @@ class VQBeTModel(nn.Module):
input_tokens.append(
self.state_projector(batch["observation.state"])
) # (batch, obs_step, projection dims)
input_tokens.append(
einops.repeat(
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
)
)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2)
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
len_additional_action_token = self.config.n_action_pred_token - 1
future_action_tokens = self.action_token.repeat(
batch_size, len_additional_action_token, 1
)
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
# add additional action query tokens for predicting future action chunks
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
@@ -427,9 +390,9 @@ class VQBeTModel(nn.Module):
features = self.policy(input_tokens)
# len(self.config.input_features) is the number of different observation modes.
# this line gets the index of action prompt tokens.
historical_act_pred_index = np.arange(0, n_obs_steps) * (
len(self.config.input_features) + 1
) + len(self.config.input_features)
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
self.config.input_features
)
# only extract the output tokens at the position of action query:
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
@@ -449,15 +412,13 @@ class VQBeTModel(nn.Module):
action_head_output = self.action_head(features)
# if rollout, VQ-BeT don't calculate loss
if rollout:
return action_head_output["predicted_action"][
:, n_obs_steps - 1, :
].reshape(batch_size, self.config.action_chunk_size, -1)
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
batch_size, self.config.action_chunk_size, -1
)
# else, it calculate overall loss (bin prediction loss, and offset loss)
else:
output = batch["action"][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(
action_head_output, output, reduction="mean"
)
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
return action_head_output, loss
@@ -492,9 +453,7 @@ class VQBeTHead(nn.Module):
else:
self.map_to_cbet_preds_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
],
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
)
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
@@ -521,10 +480,7 @@ class VQBeTHead(nn.Module):
loss, metric = self.vqvae_model.vqvae_forward(actions)
n_different_codes = sum(
[
len(torch.unique(metric[2][:, i]))
for i in range(self.vqvae_model.vqvae_num_layers)
]
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
)
n_different_combinations = len(torch.unique(metric[2], dim=0))
recon_l1_error = metric[0].detach().cpu().item()
@@ -585,18 +541,12 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
)
sampled_secondary_centers = einops.rearrange(
torch.multinomial(
cbet_secondary_probs.view(-1, choices), num_samples=1
),
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
"(NT) 1 -> NT",
NT=NT,
)
sampled_centers = torch.stack(
(sampled_primary_centers, sampled_secondary_centers), axis=1
)
cbet_logits = torch.stack(
[cbet_primary_logits, cbet_secondary_logits], dim=1
)
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
else:
cbet_logits = self.map_to_cbet_preds_bin(x)
@@ -605,9 +555,7 @@ class VQBeTHead(nn.Module):
"(NT) (G C) -> (NT) G C",
G=self.vqvae_model.vqvae_num_layers,
)
cbet_probs = torch.softmax(
cbet_logits / self.config.bet_softmax_temperature, dim=-1
)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
@@ -627,17 +575,9 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
with torch.no_grad():
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
return_decoder_input = (
self.vqvae_model.get_embeddings_from_code(sampled_centers)
.clone()
.detach()
)
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
# pass the centroids through decoder to get actions.
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
)
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
# reshaped extracted offset to match with decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
@@ -686,9 +626,7 @@ class VQBeTHead(nn.Module):
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action.
with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
# Now we can compute the loss.
@@ -711,12 +649,8 @@ class VQBeTHead(nn.Module):
+ cbet_loss2 * self.config.secondary_code_loss_weight
)
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
) / (NT)
equal_secondary_code_rate = torch.sum(
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
@@ -730,9 +664,7 @@ class VQBeTHead(nn.Module):
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
.cpu()
.item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
@@ -757,9 +689,7 @@ class VQBeTRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -780,9 +710,7 @@ class VQBeTRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
@@ -792,15 +720,11 @@ class VQBeTRgbEncoder(nn.Module):
# height and width from `config.image_features`.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = (
config.crop_shape if config.crop_shape is not None else images_shape[1:]
)
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@@ -842,11 +766,7 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@@ -861,9 +781,7 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
@@ -896,8 +814,7 @@ class VqVae(nn.Module):
)
self.encoder = MLP(
in_channels=self.config.action_feature.shape[0]
* self.config.action_chunk_size,
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
@@ -925,13 +842,9 @@ class VqVae(nn.Module):
# given latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
)
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
else:
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
)
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
def get_code(self, state):
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)

View File

@@ -123,15 +123,9 @@ class CausalSelfAttention(nn.Module):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -139,9 +133,7 @@ class CausalSelfAttention(nn.Module):
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = (
y.transpose(1, 2).contiguous().view(B, T, C)
) # re-assemble all head outputs side by side
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
@@ -197,16 +189,12 @@ class GPT(nn.Module):
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
}
)
self.lm_head = nn.Linear(
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
)
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
)
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
# report number of parameters
n_params = sum(p.numel() for p in self.parameters())
@@ -220,17 +208,11 @@ class GPT(nn.Module):
)
# positional encodings that are added to the input embeddings
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
0
) # shape (1, t)
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
# forward the GPT model itself
tok_emb = self.transformer.wte(
input
) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(
pos
) # position embeddings of shape (1, t, gpt_hidden_dim)
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
@@ -255,9 +237,7 @@ class GPT(nn.Module):
# but want to use a smaller block size for some smaller, simpler model
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(
self.transformer.wpe.weight[:gpt_block_size]
)
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
@@ -290,10 +270,8 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, (
"parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format(
@@ -390,12 +368,8 @@ class ResidualVQ(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.num_quantizers = num_quantizers
@@ -477,9 +451,7 @@ class ResidualVQ(nn.Module):
return all_codes
def forward(
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
):
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
"""
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
@@ -508,17 +480,13 @@ class ResidualVQ(nn.Module):
)
ce_losses = []
should_quantize_dropout = (
self.training and self.quantize_dropout and not return_loss
)
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
# sample a layer index at which to dropout further residual quantization
# also prepare null indices and loss
if should_quantize_dropout:
rand_quantize_dropout_index = randrange(
self.quantize_dropout_cutoff_index, num_quant
)
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = (
@@ -527,23 +495,14 @@ class ResidualVQ(nn.Module):
- 1
)
null_indices_shape = (
(x.shape[0], *x.shape[-2:])
if self.accept_image_fmap
else tuple(x.shape[:2])
)
null_indices = torch.full(
null_indices_shape, -1.0, device=device, dtype=torch.long
)
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
# go through the layers
for quantizer_index, layer in enumerate(self.layers):
if (
should_quantize_dropout
and quantizer_index > rand_quantize_dropout_index
):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
all_indices.append(null_indices)
all_losses.append(null_loss)
continue
@@ -583,9 +542,7 @@ class ResidualVQ(nn.Module):
# stack all losses and indices
all_losses, all_indices = map(
partial(torch.stack, dim=-1), (all_losses, all_indices)
)
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
ret = (quantized_out, all_indices, all_losses)
@@ -645,12 +602,8 @@ class VectorQuantize(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.eps = eps
self.commitment_weight = commitment_weight
@@ -664,14 +617,10 @@ class VectorQuantize(nn.Module):
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
assert not (ema_update and learnable_codebook), (
"learnable codebook not compatible with EMA update"
)
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
assert 0 <= sync_update_v <= 1.0
assert not (sync_update_v > 0.0 and not learnable_codebook), (
"learnable codebook must be turned on"
)
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
self.sync_update_v = sync_update_v
@@ -683,9 +632,7 @@ class VectorQuantize(nn.Module):
)
if sync_codebook is None:
sync_codebook = (
distributed.is_initialized() and distributed.get_world_size() > 1
)
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
codebook_kwargs = {
"dim": codebook_dim,
@@ -850,17 +797,11 @@ class VectorQuantize(nn.Module):
# quantize again
quantize, embed_ind, distances = self._codebook(
x, **codebook_forward_kwargs
)
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
if self.training:
# determine code to use for commitment loss
maybe_detach = (
torch.detach
if not self.learnable_codebook or freeze_codebook
else identity
)
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
commit_quantize = maybe_detach(quantize)
@@ -870,9 +811,7 @@ class VectorQuantize(nn.Module):
if self.sync_update_v > 0.0:
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
quantize = quantize + self.sync_update_v * (
quantize - quantize.detach()
)
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
# function for calculating cross entropy loss to distance matrix
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@@ -905,9 +844,7 @@ class VectorQuantize(nn.Module):
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
if self.accept_image_fmap:
embed_ind = rearrange(
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
)
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
if only_one:
embed_ind = rearrange(embed_ind, "b 1 -> b")
@@ -961,12 +898,8 @@ class VectorQuantize(nn.Module):
num_codes = codebook.shape[-2]
if (
self.orthogonal_reg_max_codes is not None
) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[
: self.orthogonal_reg_max_codes
]
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
codebook = codebook[:, rand_ids]
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
@@ -998,9 +931,7 @@ class VectorQuantize(nn.Module):
# if masking, only return quantized for where mask has True
if mask is not None:
quantize = torch.where(
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
return quantize, embed_ind, loss
@@ -1110,9 +1041,7 @@ def sample_vectors(samples, num):
def batched_sample_vectors(samples, num):
return torch.stack(
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
)
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
def pad_shape(shape, size, dim=0):
@@ -1163,9 +1092,7 @@ def sample_vectors_distributed(local_samples, num):
all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0:
samples_per_rank = sample_multinomial(
num, all_num_samples / all_num_samples.sum()
)
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
else:
samples_per_rank = torch.empty_like(all_num_samples)
@@ -1278,9 +1205,7 @@ class EuclideanCodebook(nn.Module):
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = (
reset_cluster_size
if (reset_cluster_size is not None)
else threshold_ema_dead_code
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
)
assert callable(gumbel_sample)
@@ -1291,14 +1216,8 @@ class EuclideanCodebook(nn.Module):
"kmeans init is not compatible with multiple codebooks in distributed environment for now"
)
self.sample_fn = (
sample_vectors_distributed
if use_ddp and sync_kmeans
else batched_sample_vectors
)
self.kmeans_all_reduce_fn = (
distributed.all_reduce if use_ddp and sync_kmeans else noop
)
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@@ -1437,9 +1356,7 @@ class EuclideanCodebook(nn.Module):
distributed.all_reduce(variance_number)
batch_variance = variance_number / num_vectors
self.update_with_decay(
"batch_variance", batch_variance, self.affine_param_batch_decay
)
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate(
@@ -1448,9 +1365,7 @@ class EuclideanCodebook(nn.Module):
if not torch.any(mask):
continue
sampled = self.sample_fn(
rearrange(samples, "... -> 1 ..."), mask.sum().item()
)
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
sampled = rearrange(sampled, "1 ... -> ...")
self.embed.data[ind][mask] = sampled
@@ -1474,9 +1389,7 @@ class EuclideanCodebook(nn.Module):
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = (
sample_codebook_temp
if (sample_codebook_temp is not None)
else self.sample_codebook_temp
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
)
x = x.float()
@@ -1504,9 +1417,7 @@ class EuclideanCodebook(nn.Module):
if self.affine_param:
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
embed = (embed - self.codebook_mean) * (
batch_std / codebook_std
) + self.batch_mean
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
dist = -cdist(flatten, embed)
@@ -1524,9 +1435,7 @@ class EuclideanCodebook(nn.Module):
if self.training and self.ema_update and not freeze_codebook:
if self.affine_param:
flatten = (flatten - self.batch_mean) * (
codebook_std / batch_std
) + self.codebook_mean
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
if mask is not None:
embed_onehot[~mask] = 0.0
@@ -1549,9 +1458,7 @@ class EuclideanCodebook(nn.Module):
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = tuple(
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
)
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
dist = unpack_one(dist, ps, "h * d")

View File

@@ -57,9 +57,7 @@ class OpenCVCameraConfig(CameraConfig):
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
@CameraConfig.register_subclass("intelrealsense")
@@ -104,12 +102,8 @@ class IntelRealSenseCameraConfig(CameraConfig):
self.channels = 3
at_least_one_is_not_none = (
self.fps is not None or self.width is not None or self.height is not None
)
at_least_one_is_none = (
self.fps is None or self.width is None or self.height is None
)
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
@@ -117,6 +111,4 @@ class IntelRealSenseCameraConfig(CameraConfig):
)
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")

View File

@@ -79,9 +79,7 @@ def save_image(img_array, serial_number, frame_index, images_dir):
img.save(str(path), quality=100)
logging.info(f"Saved image: {path}")
except Exception as e:
logging.error(
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
)
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
def save_images_from_cameras(
@@ -159,9 +157,7 @@ def save_images_from_cameras(
if time.perf_counter() - start_time > record_time_s:
break
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
frame_index += 1
finally:
@@ -279,9 +275,7 @@ class IntelRealSenseCamera:
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
)
name_to_serial_dict = {
cam["name"]: cam["serial_number"] for cam in camera_infos
}
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
cam_sn = name_to_serial_dict[name]
return cam_sn
@@ -353,9 +347,7 @@ class IntelRealSenseCamera:
actual_height = color_profile.height()
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
@@ -375,9 +367,7 @@ class IntelRealSenseCamera:
self.is_connected = True
def read(
self, temporary_color: str | None = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
of type `np.uint8`, contrarily to the pytorch format which is float channel first.
@@ -404,15 +394,11 @@ class IntelRealSenseCamera:
color_frame = frame.get_color_frame()
if not color_frame:
raise OSError(
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
)
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
color_image = np.asanyarray(color_frame.get_data())
requested_color_mode = (
self.color_mode if temporary_color is None else temporary_color
)
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
@@ -440,9 +426,7 @@ class IntelRealSenseCamera:
if self.use_depth:
depth_frame = frame.get_depth_frame()
if not depth_frame:
raise OSError(
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
)
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
depth_map = np.asanyarray(depth_frame.get_data())
@@ -484,9 +468,7 @@ class IntelRealSenseCamera:
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
num_tries += 1
time.sleep(1 / self.fps)
if num_tries > self.fps and (
self.thread.ident is None or not self.thread.is_alive()
):
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
raise Exception(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
)

View File

@@ -45,14 +45,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60
def find_cameras(
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
) -> list[dict]:
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
cameras = []
if platform.system() == "Linux":
print(
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
)
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
ports = _find_cameras(possible_ports, mock=mock)
for port in ports:
@@ -144,9 +140,7 @@ def save_images_from_cameras(
print("Connecting cameras")
cameras = []
for cam_idx in camera_ids:
config = OpenCVCameraConfig(
camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock
)
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
camera = OpenCVCamera(config)
camera.connect()
print(
@@ -186,9 +180,7 @@ def save_images_from_cameras(
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
if time.perf_counter() - start_time > record_time_s:
break
@@ -245,16 +237,12 @@ class OpenCVCamera:
if platform.system() == "Linux":
if isinstance(self.camera_index, int):
self.port = Path(f"/dev/video{self.camera_index}")
elif isinstance(self.camera_index, str) and is_valid_unix_path(
self.camera_index
):
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
self.port = Path(self.camera_index)
# Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port)
else:
raise ValueError(
f"Please check the provided camera_index: {self.camera_index}"
)
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
# Store the raw (capture) resolution from the config.
self.capture_width = config.width
@@ -295,9 +283,7 @@ class OpenCVCamera:
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
f"OpenCVCamera({self.camera_index}) is already connected."
)
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
if self.mock:
import tests.cameras.mock_cv2 as cv2
@@ -318,11 +304,7 @@ class OpenCVCamera:
else cv2.CAP_ANY
)
camera_idx = (
f"/dev/video{self.camera_index}"
if platform.system() == "Linux"
else self.camera_index
)
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(camera_idx, backend)
@@ -362,9 +344,7 @@ class OpenCVCamera:
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
@@ -406,9 +386,7 @@ class OpenCVCamera:
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
requested_color_mode = (
self.color_mode if temporary_color_mode is None else temporary_color_mode
)
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(

View File

@@ -93,9 +93,7 @@ class RecordControlConfig(ControlConfig):
policy_path = parser.get_path_arg("control.policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("control.policy")
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=cli_overrides
)
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path

View File

@@ -39,9 +39,7 @@ from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method
def log_control_info(
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
):
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@@ -108,9 +106,7 @@ def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and use_amp
else nullcontext(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
@@ -166,9 +162,7 @@ def init_keyboard_listener(assign_rewards=False):
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
@@ -262,9 +256,7 @@ def control_loop(
raise ValueError("You need to provide a task as argument in `single_task`.")
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(
f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
)
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0
start_episode_t = time.perf_counter()
@@ -302,9 +294,7 @@ def control_loop(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if fps is not None:
@@ -392,14 +382,11 @@ def sanity_check_dataset_robot_compatibility(
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
)
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n"
+ "\n".join(mismatches)
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)

View File

@@ -161,9 +161,7 @@ NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@@ -389,9 +387,7 @@ class DynamixelMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
except ConnectionError:
continue
@@ -407,9 +403,7 @@ class DynamixelMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@@ -430,9 +424,7 @@ class DynamixelMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -445,9 +437,7 @@ class DynamixelMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@@ -522,9 +512,7 @@ class DynamixelMotorsBus:
return values
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@@ -566,23 +554,15 @@ class DynamixelMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-(resolution // 2) - values[i] - homing_offset
) / resolution
upp_factor = (
(resolution // 2) - values[i] - homing_offset
) / resolution
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
@@ -590,9 +570,7 @@ class DynamixelMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -608,27 +586,19 @@ class DynamixelMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -638,9 +608,7 @@ class DynamixelMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@@ -679,9 +647,7 @@ class DynamixelMotorsBus:
values = np.round(values).astype(np.int32)
return values
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
@@ -783,9 +749,7 @@ class DynamixelMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@@ -794,9 +758,7 @@ class DynamixelMotorsBus:
return values
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
@@ -891,9 +853,7 @@ class DynamixelMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@@ -140,9 +140,7 @@ NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@@ -370,9 +368,7 @@ class FeetechMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
except ConnectionError:
continue
@@ -388,9 +384,7 @@ class FeetechMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@@ -411,9 +405,7 @@ class FeetechMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -426,9 +418,7 @@ class FeetechMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@@ -502,9 +492,7 @@ class FeetechMotorsBus:
return values
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@@ -543,26 +531,18 @@ class FeetechMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
) / resolution
upp_factor = (
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
@@ -571,9 +551,7 @@ class FeetechMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -589,27 +567,19 @@ class FeetechMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -619,9 +589,7 @@ class FeetechMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@@ -697,9 +665,7 @@ class FeetechMotorsBus:
return values
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
@@ -808,9 +774,7 @@ class FeetechMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@@ -819,9 +783,7 @@ class FeetechMotorsBus:
return values
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
@@ -916,9 +878,7 @@ class FeetechMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@@ -69,13 +69,9 @@ class ManipulatorRobotConfig(RobotConfig):
if not cam.mock:
cam.mock = True
if self.max_relative_target is not None and isinstance(
self.max_relative_target, Sequence
):
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(
self.max_relative_target
):
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
raise ValueError(
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "

View File

@@ -24,7 +24,9 @@ from lerobot.common.robot_devices.motors.dynamixel import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@@ -35,9 +37,7 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
def apply_drive_mode(position, drive_mode):
@@ -78,16 +78,12 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
raise ValueError("To run calibration, the torque must be disabled on all motors.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@@ -108,15 +104,10 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -125,15 +116,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# Re-compute homing offset to take into account drive mode
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
rotated_nearest_pos = compute_nearest_rounded_position(
rotated_drived_pos, arm.motor_models
)
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
homing_offset = rotated_target_pos - rotated_nearest_pos
print("\nMove arm to rest position")
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
input("Press Enter to continue...")
print()

View File

@@ -26,7 +26,9 @@ from lerobot.common.robot_devices.motors.feetech import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@@ -37,9 +39,7 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
def apply_drive_mode(position, drive_mode):
@@ -140,9 +140,7 @@ def apply_offset(calib, offset):
return calib
def run_arm_auto_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
if robot_type == "so100":
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
elif robot_type == "moss":
@@ -151,27 +149,18 @@ def run_arm_auto_calibration(
raise ValueError(robot_type)
def run_arm_auto_calibration_so100(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
raise ValueError("To run calibration, the torque must be disabled on all motors.")
if not (robot_type == "so100" and arm_type == "follower"):
raise NotImplementedError(
"Auto calibration only supports the follower of so100 arms for now."
)
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@@ -225,9 +214,7 @@ def run_arm_auto_calibration_so100(
)
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
time.sleep(1)
def in_between_move_hook():
@@ -261,13 +248,9 @@ def run_arm_auto_calibration_so100(
"shoulder_lift",
)
time.sleep(2)
arm.write(
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
)
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
time.sleep(2)
arm.write(
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
)
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
time.sleep(2)
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
time.sleep(2)
@@ -288,9 +271,7 @@ def run_arm_auto_calibration_so100(
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1)
@@ -319,27 +300,18 @@ def run_arm_auto_calibration_so100(
return calib_dict
def run_arm_auto_calibration_moss(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
raise ValueError("To run calibration, the torque must be disabled on all motors.")
if not (robot_type == "moss" and arm_type == "follower"):
raise NotImplementedError(
"Auto calibration only supports the follower of moss arms for now."
)
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@@ -423,12 +395,8 @@ def run_arm_auto_calibration_moss(
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1)
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
time.sleep(2)
calib_modes = []
@@ -455,9 +423,7 @@ def run_arm_auto_calibration_moss(
return calib_dict
def run_arm_manual_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance before calibration, setting a same goal position
for each motor of two different robots will get two very different positions. But after calibration,
@@ -480,16 +446,12 @@ def run_arm_manual_calibration(
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
raise ValueError("To run calibration, the torque must be disabled on all motors.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@@ -509,15 +471,10 @@ def run_arm_manual_calibration(
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -529,9 +486,7 @@ def run_arm_manual_calibration(
homing_offset = rotated_target_pos - rotated_drived_pos
print("\nMove arm to rest position")
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
input("Press Enter to continue...")
print()

View File

@@ -42,9 +42,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
local_dict = {}
for name, cam in cameras.items():
frame = cam.async_read()
ret, buffer = cv2.imencode(
".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]
)
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
if ret:
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
else:
@@ -76,9 +74,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
print(f"[INFO] Loaded calibration from {calib_file}")
else:
print("[INFO] Calibration file not found. Running manual calibration...")
calibration = run_arm_manual_calibration(
motors_bus, "lekiwi", "follower_arm", "follower"
)
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
print(f"[INFO] Calibration complete. Saving to {calib_file}")
with open(calib_file, "w") as f:
json.dump(calibration, f)
@@ -174,9 +170,7 @@ def run_lekiwi(robot_config):
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
)
else:
for motor, pos in zip(
arm_motor_ids, arm_positions, strict=False
):
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
motors_bus.write("Goal_Position", pos, motor)
# Process wheel (base) commands.
if "raw_velocity" in data:
@@ -207,9 +201,7 @@ def run_lekiwi(robot_config):
try:
pos = motors_bus.read("Present_Position", motor)
# Convert the position to a float (or use as is if already numeric).
follower_arm_state.append(
float(pos) if not isinstance(pos, (int, float)) else pos
)
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
except Exception as e:
print(f"[ERROR] Reading motor {motor} failed: {e}")

View File

@@ -285,9 +285,7 @@ class ManipulatorRobot:
# to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
# Check both arms can be read
for name in self.follower_arms:
@@ -323,22 +321,16 @@ class ManipulatorRobot:
run_arm_calibration,
)
calibration = run_arm_calibration(
arm, self.robot_type, name, arm_type
)
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
print(
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
@@ -357,17 +349,13 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run set robot preset, the torque must be disabled on all motors."
)
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name for name in arm.motor_names if name != "gripper"
]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Koch motors
arm.write("Operating_Mode", 4, all_motors_except_gripper)
@@ -396,9 +384,7 @@ class ManipulatorRobot:
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms.
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
def set_aloha_robot_preset(self):
def set_shadow_(arm):
@@ -428,15 +414,11 @@ class ManipulatorRobot:
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name
for name in self.follower_arms[name].motor_names
if name != "gripper"
name for name in self.follower_arms[name].motor_names if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Aloha motors
self.follower_arms[name].write(
"Operating_Mode", 4, all_motors_except_gripper
)
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
# It can grasp an object without forcing too much even tho,
@@ -484,9 +466,7 @@ class ManipulatorRobot:
before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = (
time.perf_counter() - before_lread_t
)
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
# Send goal position to the follower
follower_goal_pos = {}
@@ -507,18 +487,14 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
# Used when record_data=True
follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
time.perf_counter() - before_fwrite_t
)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
# Early exit when recording data is not requested
if not record_data:
@@ -531,9 +507,7 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
# Create state by concatenating follower current position
state = []
@@ -555,12 +529,8 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict, action_dict = {}, {}
@@ -584,9 +554,7 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
# Create state by concatenating follower current position
state = []
@@ -601,12 +569,8 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries and format to pytorch
obs_dict = {}
@@ -652,9 +616,7 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
# Save tensor to concat and return
action_sent.append(goal_pos)

View File

@@ -271,9 +271,7 @@ class MobileManipulator:
calibration = json.load(f)
else:
print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
@@ -303,9 +301,7 @@ class MobileManipulator:
bus.write("Torque_Enable", 0, motor_id)
# Then filter out wheels
arm_only_dict = {
k: v for k, v in bus.motors.items() if not k.startswith("wheel_")
}
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
if not arm_only_dict:
continue
@@ -377,9 +373,7 @@ class MobileManipulator:
if new_arm_state is not None and frames is not None:
self.last_frames = frames
remote_arm_state_tensor = torch.tensor(
new_arm_state, dtype=torch.float32
)
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
self.last_remote_arm_state = remote_arm_state_tensor
present_speed = new_speed
@@ -405,10 +399,7 @@ class MobileManipulator:
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
state_tensor = torch.zeros(3, dtype=torch.int32)
if present_speed:
decoded = {
key: MobileManipulator.raw_to_degps(value)
for key, value in present_speed.items()
}
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
if "1" in decoded:
state_tensor[0] = decoded["1"]
if "2" in decoded:
@@ -421,9 +412,7 @@ class MobileManipulator:
self, record_data: bool = False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"MobileManipulator is not connected. Run `connect()` first."
)
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
speed_setting = self.speed_levels[self.speed_index]
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
@@ -495,9 +484,7 @@ class MobileManipulator:
body_state[2],
) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat(
(remote_arm_state_tensor, wheel_state_tensor), dim=0
)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
obs_dict = {"observation.state": combined_state_tensor}

View File

@@ -52,9 +52,7 @@ class StretchRobot(StretchAPI):
def connect(self) -> None:
self.is_connected = self.startup()
if not self.is_connected:
print(
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
)
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError()
for name in self.cameras:
@@ -62,9 +60,7 @@ class StretchRobot(StretchAPI):
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print(
"Could not connect to the cameras, check that all cameras are plugged-in."
)
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.run_calibration()
@@ -109,12 +105,8 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict, action_dict = {}, {}
@@ -158,12 +150,8 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict = {}

View File

@@ -69,9 +69,7 @@ class HubMixin:
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(
repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs
)
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return None
def _save_pretrained(self, save_directory: Path) -> None:
@@ -177,9 +175,7 @@ class HubMixin:
The url of the commit of your object in the given repository.
"""
api = HfApi(token=token)
repo_id = api.create_repo(
repo_id=repo_id, private=private, exist_ok=True
).repo_id
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
if commit_message is None:
if "Policy" in self.__class__.__name__:

View File

@@ -17,9 +17,7 @@ import importlib
import logging
def is_package_available(
pkg_name: str, return_version: bool = False
) -> tuple[bool, str] | bool:
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.

View File

@@ -20,16 +20,7 @@ from typing import TypeVar
import imageio
JsonLike = (
str
| int
| float
| bool
| None
| list["JsonLike"]
| dict[str, "JsonLike"]
| tuple["JsonLike", ...]
)
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
T = TypeVar("T", bound=JsonLike)
@@ -85,9 +76,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# Check length
if len(target) != len(source):
raise ValueError(
f"List length mismatch: expected {len(target)}, got {len(source)}"
)
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
# Recursively update each element.
for i in range(len(target)):
@@ -99,14 +88,10 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# which we'll convert back to a tuple.
elif isinstance(target, tuple):
if not isinstance(source, list):
raise TypeError(
f"Type mismatch: expected list (for tuple), got {type(source)}"
)
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
if len(target) != len(source):
raise ValueError(
f"Tuple length mismatch: expected {len(target)}, got {len(source)}"
)
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
# Convert each element, forming a new tuple.
converted_items = []
@@ -120,9 +105,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
else:
# Check the exact type. If these must match 1:1, do:
if type(target) is not type(source):
raise TypeError(
f"Type mismatch: expected {type(target)}, got {type(source)}"
)
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
return source
# Perform the in-place/recursive deserialization

View File

@@ -107,17 +107,13 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def __getattr__(
self, name: str
) -> int | dict[str, AverageMeter] | AverageMeter | Any:
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
return self.__dict__[name]
elif name in self.metrics:
return self.metrics[name]
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __setattr__(self, name: str, value: Any) -> None:
if name in self.__dict__:
@@ -125,9 +121,7 @@ class MetricsTracker:
elif name in self.metrics:
self.metrics[name].update(value)
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def step(self) -> None:
"""

View File

@@ -123,9 +123,7 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = {
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
}
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict)

View File

@@ -48,9 +48,7 @@ def auto_select_torch_device() -> torch.device:
logging.info("Metal backend detected, using cuda.")
return torch.device("mps")
else:
logging.warning(
"No accelerated backend detected. Using default cpu, this will be slow."
)
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
return torch.device("cpu")
@@ -98,9 +96,7 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu":
return True
else:
raise ValueError(
f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu."
)
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
def is_amp_available(device: str):
@@ -158,10 +154,7 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join(
[".."] * (len(path2.parts) - len(common_parts))
+ list(path1.parts[len(common_parts) :])
)
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
@@ -172,26 +165,10 @@ def print_cuda_memory_usage():
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print(
"Current GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.memory_allocated(0) / 1024**2
)
)
print(
"Maximum GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.max_memory_allocated(0) / 1024**2
)
)
print(
"Current GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.memory_reserved(0) / 1024**2
)
)
print(
"Maximum GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.max_memory_reserved(0) / 1024**2
)
)
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
def capture_timestamp_utc():
@@ -223,9 +200,7 @@ def say(text, blocking=False):
if blocking:
subprocess.run(cmd, check=True)
else:
subprocess.Popen(
cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
)
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
def log_say(text, play_sounds, blocking=False):

View File

@@ -26,9 +26,7 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig
def cfg_to_group(
cfg: TrainPipelineConfig, return_list: bool = False
) -> list[str] | str:
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
@@ -95,9 +93,7 @@ class WandBLogger:
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
def log_policy(self, checkpoint_dir: Path):
@@ -109,9 +105,7 @@ class WandBLogger:
artifact_name = f"{self._group}-{step_id}"
artifact_name = get_safe_wandb_artifact_name(artifact_name)
artifact = self._wandb.Artifact(artifact_name, type="model")
artifact.add_file(
checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE
)
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"):