[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
commit 7c05755823
parent 2945bbb221
committed by AdilZouitine
@@ -46,18 +46,14 @@ def sample_indices(data_len: int) -> list[int]:
     return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()


-def auto_downsample_height_width(
-    img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
-):
+def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
     _, height, width = img.shape

     if max(width, height) < max_size_threshold:
         # no downsampling needed
         return img

-    downsample_factor = (
-        int(width / target_size) if width > height else int(height / target_size)
-    )
+    downsample_factor = int(width / target_size) if width > height else int(height / target_size)
     return img[:, ::downsample_factor, ::downsample_factor]

@@ -79,9 +75,7 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
     return images


-def get_feature_stats(
-    array: np.ndarray, axis: tuple, keepdims: bool
-) -> dict[str, np.ndarray]:
+def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
     return {
         "min": np.min(array, axis=axis, keepdims=keepdims),
         "max": np.max(array, axis=axis, keepdims=keepdims),

@@ -91,9 +85,7 @@ def get_feature_stats(
     }


-def compute_episode_stats(
-    episode_data: dict[str, list[str] | np.ndarray], features: dict
-) -> dict:
+def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
     ep_stats = {}
     for key, data in episode_data.items():
         if features[key]["dtype"] == "string":

@@ -107,15 +99,12 @@ def compute_episode_stats(
             axes_to_reduce = 0  # compute stats over the first axis
             keepdims = data.ndim == 1  # keep as np.array

-        ep_stats[key] = get_feature_stats(
-            ep_ft_array, axis=axes_to_reduce, keepdims=keepdims
-        )
+        ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)

         # finally, we normalize and remove batch dim for images
         if features[key]["dtype"] in ["image", "video"]:
             ep_stats[key] = {
-                k: v if k == "count" else np.squeeze(v / 255.0, axis=0)
-                for k, v in ep_stats[key].items()
+                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
             }

     return ep_stats

@@ -130,17 +119,11 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
                         f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
                     )
                 if v.ndim == 0:
-                    raise ValueError(
-                        "Number of dimensions must be at least 1, and is 0 instead."
-                    )
+                    raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
                 if k == "count" and v.shape != (1,):
-                    raise ValueError(
-                        f"Shape of 'count' must be (1), but is {v.shape} instead."
-                    )
+                    raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
                 if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
-                    raise ValueError(
-                        f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead."
-                    )
+                    raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")


 def aggregate_feature_stats(
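For orientation, a minimal usage sketch of the stats helpers above — the module path `lerobot.common.datasets.compute_stats` and the exact return keys are assumptions based on this diff, not verified against the repository:

    import numpy as np

    from lerobot.common.datasets.compute_stats import auto_downsample_height_width, get_feature_stats

    # (C, H, W) image large enough to trigger downsampling (max side >= 300)
    img = np.random.randint(0, 256, size=(3, 480, 640), dtype=np.uint8)
    small = auto_downsample_height_width(img)  # strided by int(640 / 150) == 4
    assert small.shape == (3, 120, 160)

    # Per-channel stats over a batch, keeping dims as in the image/video code path
    batch = np.stack([small, small]).astype(np.float64)
    stats = get_feature_stats(batch, axis=(0, 2, 3), keepdims=True)
    assert stats["min"].shape == (1, 3, 1, 1)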
@@ -58,9 +58,7 @@ def resolve_delta_timestamps(
         if key == "action" and cfg.action_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
         if key.startswith("observation.") and cfg.observation_delta_indices is not None:
-            delta_timestamps[key] = [
-                i / ds_meta.fps for i in cfg.observation_delta_indices
-            ]
+            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]

     if len(delta_timestamps) == 0:
         delta_timestamps = None

@@ -81,9 +79,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
         LeRobotDataset | MultiLeRobotDataset
     """
     image_transforms = (
-        ImageTransforms(cfg.dataset.image_transforms)
-        if cfg.dataset.image_transforms.enable
-        else None
+        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
     )

     if isinstance(cfg.dataset.repo_id, str):

@@ -117,8 +113,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
     if cfg.dataset.use_imagenet_stats:
         for key in dataset.meta.camera_keys:
             for stats_type, stats in IMAGENET_STATS.items():
-                dataset.meta.stats[key][stats_type] = torch.tensor(
-                    stats, dtype=torch.float32
-                )
+                dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

     return dataset
@@ -38,14 +38,10 @@ def safe_stop_image_writer(func):
     return wrapper


-def image_array_to_pil_image(
-    image_array: np.ndarray, range_check: bool = True
-) -> PIL.Image.Image:
+def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
     # TODO(aliberts): handle 1 channel and 4 for depth images
     if image_array.ndim != 3:
-        raise ValueError(
-            f"The array has {image_array.ndim} dimensions, but 3 is expected for an image."
-        )
+        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")

     if image_array.shape[0] == 3:
         # Transpose from pytorch convention (C, H, W) to (H, W, C)

@@ -131,9 +127,7 @@ class AsyncImageWriter:
         self._stopped = False

         if num_threads <= 0 and num_processes <= 0:
-            raise ValueError(
-                "Number of threads and processes must be greater than zero."
-            )
+            raise ValueError("Number of threads and processes must be greater than zero.")

         if self.num_processes == 0:
             # Use threading

@@ -147,16 +141,12 @@ class AsyncImageWriter:
             # Use multiprocessing
             self.queue = multiprocessing.JoinableQueue()
             for _ in range(self.num_processes):
-                p = multiprocessing.Process(
-                    target=worker_process, args=(self.queue, self.num_threads)
-                )
+                p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
                 p.daemon = True
                 p.start()
                 self.processes.append(p)

-    def save_image(
-        self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
-    ):
+    def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
         if isinstance(image, torch.Tensor):
             # Convert tensor to numpy array to minimize main process time
             image = image.cpu().numpy()
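A hedged sketch of driving the writer reformatted above — the constructor arguments, the `wait_until_done` call, and the module path are assumptions inferred from this diff, not verified against the repository:

    from pathlib import Path

    import numpy as np

    from lerobot.common.datasets.image_writer import AsyncImageWriter

    writer = AsyncImageWriter(num_processes=0, num_threads=4)  # threads only, per the `num_processes == 0` branch
    frame = np.zeros((96, 96, 3), dtype=np.uint8)  # (H, W, C) uint8 image
    writer.save_image(frame, Path("/tmp/frame_000000.png"))
    writer.wait_until_done()  # assumed flush/join helper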
@@ -108,9 +108,7 @@ class LeRobotDatasetMetadata:
         self.episodes = load_episodes(self.root)
         if self._version < packaging.version.parse("v2.1"):
             self.stats = load_stats(self.root)
-            self.episodes_stats = backward_compatible_episodes_stats(
-                self.stats, self.episodes
-            )
+            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
         else:
             self.episodes_stats = load_episodes_stats(self.root)
             self.stats = aggregate_stats(list(self.episodes_stats.values()))

@@ -141,9 +139,7 @@ class LeRobotDatasetMetadata:

     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.video_path.format(
-            episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
-        )
+        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)

     def get_episode_chunk(self, ep_index: int) -> int:

@@ -187,11 +183,7 @@ class LeRobotDatasetMetadata:
     @property
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
-        return [
-            key
-            for key, ft in self.features.items()
-            if ft["dtype"] in ["video", "image"]
-        ]
+        return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]

     @property
     def names(self) -> dict[str, list | dict]:

@@ -240,9 +232,7 @@ class LeRobotDatasetMetadata:
         Given a task in natural language, add it to the dictionary of tasks.
         """
         if task in self.task_to_task_index:
-            raise ValueError(
-                f"The task '{task}' already exists and can't be added twice."
-            )
+            raise ValueError(f"The task '{task}' already exists and can't be added twice.")

         task_index = self.info["total_tasks"]
         self.task_to_task_index[task] = task_index

@@ -285,11 +275,7 @@ class LeRobotDatasetMetadata:
         write_episode(episode_dict, self.root)

         self.episodes_stats[episode_index] = episode_stats
-        self.stats = (
-            aggregate_stats([self.stats, episode_stats])
-            if self.stats
-            else episode_stats
-        )
+        self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
         write_episode_stats(episode_index, episode_stats, self.root)

     def update_video_info(self) -> None:

@@ -299,9 +285,7 @@ class LeRobotDatasetMetadata:
         """
         for key in self.video_keys:
             if not self.features[key].get("info", None):
-                video_path = self.root / self.get_video_file_path(
-                    ep_index=0, vid_key=key
-                )
+                video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
                 self.info["features"][key]["info"] = get_video_info(video_path)

     def __repr__(self):

@@ -353,17 +337,13 @@ class LeRobotDatasetMetadata:
         # as this would break the dict flattening in the stats computation, which uses '/' as separator
         for key in features:
             if "/" in key:
-                raise ValueError(
-                    f"Feature names should not contain '/'. Found '/' in feature '{key}'."
-                )
+                raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")

         features = {**features, **DEFAULT_FEATURES}

         obj.tasks, obj.task_to_task_index = {}, {}
         obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
-        obj.info = create_empty_dataset_info(
-            CODEBASE_VERSION, fps, robot_type, features, use_videos
-        )
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
         write_json(obj.info, obj.root / INFO_PATH)

@@ -494,9 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
-        self.video_backend = (
-            video_backend if video_backend else get_safe_default_codec()
-        )
+        self.video_backend = video_backend if video_backend else get_safe_default_codec()
         self.delta_indices = None

         # Unused attributes
@@ -509,39 +487,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.meta = LeRobotDatasetMetadata(
             self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
         )
-        if self.episodes is not None and self.meta._version >= packaging.version.parse(
-            "v2.1"
-        ):
-            episodes_stats = [
-                self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
-            ]
+        if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
+            episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
             self.stats = aggregate_stats(episodes_stats)

         # Load actual data
         try:
             if force_cache_sync:
                 raise FileNotFoundError
-            assert all(
-                (self.root / fpath).is_file()
-                for fpath in self.get_episodes_file_paths()
-            )
+            assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
             self.download_episodes(download_videos)
             self.hf_dataset = self.load_hf_dataset()

-        self.episode_data_index = get_episode_data_index(
-            self.meta.episodes, self.episodes
-        )
+        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

         # Check timestamps
         timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
         episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
         ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
-        check_timestamps_sync(
-            timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
-        )
+        check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

         # Setup delta_indices
         if self.delta_timestamps is not None:

@@ -593,9 +560,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         else:
             hub_api.upload_folder(**upload_kwargs)

-        if not hub_api.file_exists(
-            self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
-        ):
+        if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
             card = create_lerobot_dataset_card(
                 tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
             )

@@ -603,12 +568,8 @@ class LeRobotDataset(torch.utils.data.Dataset):

         if tag_version:
             with contextlib.suppress(RevisionNotFoundError):
-                hub_api.delete_tag(
-                    self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
-                )
-            hub_api.create_tag(
-                self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
-            )
+                hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+            hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

     def pull_from_repo(
         self,

@@ -640,11 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

     def get_episodes_file_paths(self) -> list[Path]:
-        episodes = (
-            self.episodes
-            if self.episodes is not None
-            else list(range(self.meta.total_episodes))
-        )
+        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
         fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
         if len(self.meta.video_keys) > 0:
             video_files = [

@@ -662,10 +619,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             path = str(self.root / "data")
             hf_dataset = load_dataset("parquet", data_dir=path, split="train")
         else:
-            files = [
-                str(self.root / self.meta.get_data_file_path(ep_idx))
-                for ep_idx in self.episodes
-            ]
+            files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
             hf_dataset = load_dataset("parquet", data_files=files, split="train")

         # TODO(aliberts): hf_dataset.set_format("torch")

@@ -675,9 +629,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def create_hf_dataset(self) -> datasets.Dataset:
         features = get_hf_features_from_features(self.features)
         ft_dict = {col: [] for col in features}
-        hf_dataset = datasets.Dataset.from_dict(
-            ft_dict, features=features, split="train"
-        )
+        hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")

         # TODO(aliberts): hf_dataset.set_format("torch")
         hf_dataset.set_transform(hf_transform_to_torch)

@@ -691,20 +643,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @property
     def num_frames(self) -> int:
         """Number of frames in selected episodes."""
-        return (
-            len(self.hf_dataset)
-            if self.hf_dataset is not None
-            else self.meta.total_frames
-        )
+        return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames

     @property
     def num_episodes(self) -> int:
         """Number of episodes selected."""
-        return (
-            len(self.episodes)
-            if self.episodes is not None
-            else self.meta.total_episodes
-        )
+        return len(self.episodes) if self.episodes is not None else self.meta.total_episodes

     @property
     def features(self) -> dict[str, dict]:
@@ -718,24 +662,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         else:
             return get_hf_features_from_features(self.features)

-    def _get_query_indices(
-        self, idx: int, ep_idx: int
-    ) -> tuple[dict[str, list[int | bool]]]:
+    def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
         ep_start = self.episode_data_index["from"][ep_idx]
         ep_end = self.episode_data_index["to"][ep_idx]
         query_indices = {
-            key: [
-                max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
-                for delta in delta_idx
-            ]
+            key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
             for key, delta_idx in self.delta_indices.items()
         }
         padding = {  # Pad values outside of current episode range
             f"{key}_is_pad": torch.BoolTensor(
-                [
-                    (idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
-                    for delta in delta_idx
-                ]
+                [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
             )
             for key, delta_idx in self.delta_indices.items()
         }

@@ -763,9 +699,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             if key not in self.meta.video_keys
         }

-    def _query_videos(
-        self, query_timestamps: dict[str, list[float]], ep_idx: int
-    ) -> dict[str, torch.Tensor]:
+    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
         """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
         in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
         Segmentation Fault. This probably happens because a memory reference to the video loader is created in

@@ -774,9 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item = {}
         for vid_key, query_ts in query_timestamps.items():
             video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
-            frames = decode_video_frames(
-                video_path, query_ts, self.tolerance_s, self.video_backend
-            )
+            frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
             item[vid_key] = frames.squeeze(0)

         return item

@@ -830,9 +762,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )

     def create_episode_buffer(self, episode_index: int | None = None) -> dict:
-        current_ep_idx = (
-            self.meta.total_episodes if episode_index is None else episode_index
-        )
+        current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
         ep_buffer = {}
         # size and task are special cases that are not in self.features
         ep_buffer["size"] = 0

@@ -841,17 +771,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ep_buffer[key] = current_ep_idx if key == "episode_index" else []
         return ep_buffer

-    def _get_image_file_path(
-        self, episode_index: int, image_key: str, frame_index: int
-    ) -> Path:
+    def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
         fpath = DEFAULT_IMAGE_PATH.format(
             image_key=image_key, episode_index=episode_index, frame_index=frame_index
         )
         return self.root / fpath

-    def _save_image(
-        self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
-    ) -> None:
+    def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
         if self.image_writer is None:
             if isinstance(image, torch.Tensor):
                 image = image.cpu().numpy()

@@ -877,9 +803,7 @@ class LeRobotDataset(torch.utils.data.Dataset):

         # Automatically add frame_index and timestamp to episode buffer
         frame_index = self.episode_buffer["size"]
-        timestamp = (
-            frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
-        )
+        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
         self.episode_buffer["frame_index"].append(frame_index)
         self.episode_buffer["timestamp"].append(timestamp)

@@ -930,9 +854,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         episode_tasks = list(set(tasks))
         episode_index = episode_buffer["episode_index"]

-        episode_buffer["index"] = np.arange(
-            self.meta.total_frames, self.meta.total_frames + episode_length
-        )
+        episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
         episode_buffer["episode_index"] = np.full((episode_length,), episode_index)

         # Add new tasks to the tasks dictionary

@@ -942,9 +864,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self.meta.add_task(task)

         # Given tasks in natural language, find their corresponding task indices
-        episode_buffer["task_index"] = np.array(
-            [self.meta.get_task_index(task) for task in tasks]
-        )
+        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])

         for key, ft in self.features.items():
             # index, episode_index, task_index are already processed above, and image and video

@@ -994,9 +914,7 @@ class LeRobotDataset(torch.utils.data.Dataset):

     def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
         episode_dict = {key: episode_buffer[key] for key in self.hf_features}
-        ep_dataset = datasets.Dataset.from_dict(
-            episode_dict, features=self.hf_features, split="train"
-        )
+        ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
         ep_dataset = embed_images(ep_dataset)
         self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
         self.hf_dataset.set_transform(hf_transform_to_torch)

@@ -1115,9 +1033,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_timestamps = None
         obj.delta_indices = None
         obj.episode_data_index = None
-        obj.video_backend = (
-            video_backend if video_backend is not None else get_safe_default_codec()
-        )
+        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
         return obj
@@ -1142,9 +1058,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         super().__init__()
         self.repo_ids = repo_ids
         self.root = Path(root) if root else HF_LEROBOT_HOME
-        self.tolerances_s = (
-            tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
-        )
+        self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
         # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
         # are handled by this class.
         self._datasets = [

@@ -1223,13 +1137,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def features(self) -> datasets.Features:
         features = {}
         for dataset in self._datasets:
-            features.update(
-                {
-                    k: v
-                    for k, v in dataset.hf_features.items()
-                    if k not in self.disabled_features
-                }
-            )
+            features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
         return features

     @property

@@ -1290,9 +1198,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 continue
             break
         else:
-            raise AssertionError(
-                "We expect the loop to break out as long as the index is within bounds."
-            )
+            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
         item = self._datasets[dataset_idx][idx - start_idx]
         item["dataset_index"] = torch.tensor(dataset_idx)
         for data_key in self.disabled_features:
@@ -131,9 +131,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
         else:
             self._delta_timestamps = None

-    def _make_data_spec(
-        self, data_spec: dict[str, Any], buffer_capacity: int
-    ) -> dict[str, dict[str, Any]]:
+    def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
         """Makes the data spec for np.memmap."""
         if any(k.startswith("_") for k in data_spec):
             raise ValueError(

@@ -208,9 +206,7 @@ class OnlineBuffer(torch.utils.data.Dataset):

         # Shift the incoming indices if necessary.
         if self.num_frames > 0:
-            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
-                next_index - 1
-            ]
+            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
             last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
             data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
             data[OnlineBuffer.INDEX_KEY] += last_data_index + 1

@@ -245,11 +241,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
     @property
     def num_episodes(self) -> int:
         return len(
-            np.unique(
-                self._data[OnlineBuffer.EPISODE_INDEX_KEY][
-                    self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
-                ]
-            )
+            np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
         )

     @property

@@ -287,9 +279,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
                 self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
             )
         )[0]
-        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
-            episode_data_indices
-        ]
+        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]

         for data_key in self.delta_timestamps:
             # Note: The logic in this loop is copied from `load_previous_and_future_frames`.

@@ -306,8 +296,7 @@ class OnlineBuffer(torch.utils.data.Dataset):

             # Check violated query timestamps are all outside the episode range.
             assert (
-                (query_ts[is_pad] < episode_timestamps[0])
-                | (episode_timestamps[-1] < query_ts[is_pad])
+                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
             ).all(), (
                 f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
                 ") inside the episode range."

@@ -322,9 +311,7 @@ class OnlineBuffer(torch.utils.data.Dataset):

     def get_data_by_key(self, key: str) -> torch.Tensor:
         """Returns all data for a given data key as a Tensor."""
-        return torch.from_numpy(
-            self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
-        )
+        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])


 def compute_sampler_weights(

@@ -355,19 +342,13 @@ def compute_sampler_weights(
     - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
       included here to avoid adding complexity.
    """
-    if len(offline_dataset) == 0 and (
-        online_dataset is None or len(online_dataset) == 0
-    ):
-        raise ValueError(
-            "At least one of `offline_dataset` or `online_dataset` should be contain data."
-        )
+    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
+        raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
     if (online_dataset is None) ^ (online_sampling_ratio is None):
         raise ValueError(
             "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
         )
-    offline_sampling_ratio = (
-        0 if online_sampling_ratio is None else 1 - online_sampling_ratio
-    )
+    offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio

     weights = []
@@ -45,9 +45,7 @@ def concatenate_episodes(ep_dicts):
     return data_dict


-def save_images_concurrently(
-    imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
-):
+def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
     out_dir = Path(out_dir)
     out_dir.mkdir(parents=True, exist_ok=True)

@@ -57,10 +55,7 @@ def save_images_concurrently(

     num_images = len(imgs_array)
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        [
-            executor.submit(save_image, imgs_array[i], i, out_dir)
-            for i in range(num_images)
-        ]
+        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]


 def get_default_encoding() -> dict:

@@ -69,8 +64,7 @@ def get_default_encoding() -> dict:
     return {
         k: v.default
         for k, v in signature.parameters.items()
-        if v.default is not inspect.Parameter.empty
-        and k in ["vcodec", "pix_fmt", "g", "crf"]
+        if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
     }
@@ -58,9 +58,7 @@ class RandomSubsetApply(Transform):
         elif not isinstance(n_subset, int):
             raise TypeError("n_subset should be an int or None")
         elif not (1 <= n_subset <= len(transforms)):
-            raise ValueError(
-                f"n_subset should be in the interval [1, {len(transforms)}]"
-            )
+            raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")

         self.transforms = transforms
         total = sum(p)

@@ -121,36 +119,26 @@ class SharpnessJitter(Transform):
     def _check_input(self, sharpness):
         if isinstance(sharpness, (int, float)):
            if sharpness < 0:
-                raise ValueError(
-                    "If sharpness is a single number, it must be non negative."
-                )
+                raise ValueError("If sharpness is a single number, it must be non negative.")
             sharpness = [1.0 - sharpness, 1.0 + sharpness]
             sharpness[0] = max(sharpness[0], 0.0)
         elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
             sharpness = [float(v) for v in sharpness]
         else:
-            raise TypeError(
-                f"{sharpness=} should be a single number or a sequence with length 2."
-            )
+            raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")

         if not 0.0 <= sharpness[0] <= sharpness[1]:
-            raise ValueError(
-                f"sharpnesss values should be between (0., inf), but got {sharpness}."
-            )
+            raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")

         return float(sharpness[0]), float(sharpness[1])

     def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
-        sharpness_factor = (
-            torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
-        )
+        sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
         return {"sharpness_factor": sharpness_factor}

     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         sharpness_factor = params["sharpness_factor"]
-        return self._call_kernel(
-            F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
-        )
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)


 @dataclass
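A small usage sketch of the SharpnessJitter transform as reformatted above — the module path `lerobot.common.datasets.transforms` and the constructor keyword are assumptions based on this diff:

    import torch

    from lerobot.common.datasets.transforms import SharpnessJitter

    jitter = SharpnessJitter(sharpness=0.5)  # _check_input maps 0.5 to the range [0.5, 1.5]
    img = torch.rand(3, 224, 224)  # float image in [0, 1]
    out = jitter(img)
    assert out.shape == img.shape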
@@ -52,15 +52,9 @@ STATS_PATH = "meta/stats.json"
 EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
 TASKS_PATH = "meta/tasks.jsonl"

-DEFAULT_VIDEO_PATH = (
-    "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
-)
-DEFAULT_PARQUET_PATH = (
-    "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
-)
-DEFAULT_IMAGE_PATH = (
-    "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
-)
+DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
+DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
+DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

 DATASET_CARD_TEMPLATE = """
 ---

@@ -135,9 +129,7 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
         elif isinstance(value, (int, float)):
             serialized_dict[key] = value
         else:
-            raise NotImplementedError(
-                f"The value '{value}' of type '{type(value)}' is not supported."
-            )
+            raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
     return unflatten_dict(serialized_dict)

@@ -216,10 +208,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):

 def load_tasks(local_dir: Path) -> tuple[dict, dict]:
     tasks = load_jsonlines(local_dir / TASKS_PATH)
-    tasks = {
-        item["task_index"]: item["task"]
-        for item in sorted(tasks, key=lambda x: x["task_index"])
-    }
+    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
     task_to_task_index = {task: task_index for task_index, task in tasks.items()}
     return tasks, task_to_task_index

@@ -230,10 +219,7 @@ def write_episode(episode: dict, local_dir: Path):

 def load_episodes(local_dir: Path) -> dict:
     episodes = load_jsonlines(local_dir / EPISODES_PATH)
-    return {
-        item["episode_index"]: item
-        for item in sorted(episodes, key=lambda x: x["episode_index"])
-    }
+    return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}


 def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):

@@ -286,9 +272,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         elif first_item is None:
             pass
         else:
-            items_dict[key] = [
-                x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]
-            ]
+            items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
     return items_dict

@@ -341,9 +325,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
     Otherwise, will throw a `CompatibilityError`.
     """
     target_version = (
-        packaging.version.parse(version)
-        if not isinstance(version, packaging.version.Version)
-        else version
+        packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
     )
     hub_versions = get_repo_versions(repo_id)

@@ -364,16 +346,12 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
         return f"v{target_version}"

     compatibles = [
-        v
-        for v in hub_versions
-        if v.major == target_version.major and v.minor <= target_version.minor
+        v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
     ]
     if compatibles:
         return_version = max(compatibles)
         if return_version < target_version:
-            logging.warning(
-                f"Revision {version} for {repo_id} not found, using version v{return_version}"
-            )
+            logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
         return f"v{return_version}"

     lower_major = [v for v in hub_versions if v.major < target_version.major]

@@ -480,9 +458,7 @@ def create_empty_dataset_info(
 def get_episode_data_index(
     episode_dicts: dict[dict], episodes: list[int] | None = None
 ) -> dict[str, torch.Tensor]:
-    episode_lengths = {
-        ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()
-    }
+    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
     if episodes is not None:
         episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -532,9 +508,7 @@ def check_timestamps_sync(

     # Mask to ignore differences at the boundaries between episodes
     mask = np.ones(len(diffs), dtype=bool)
-    ignored_diffs = (
-        episode_data_index["to"][:-1] - 1
-    )  # indices at the end of each episode
+    ignored_diffs = episode_data_index["to"][:-1] - 1  # indices at the end of each episode
     mask[ignored_diffs] = False
     filtered_within_tolerance = within_tolerance[mask]

@@ -580,14 +554,10 @@ def check_delta_timestamps(
     """
     outside_tolerance = {}
     for key, delta_ts in delta_timestamps.items():
-        within_tolerance = [
-            abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
-        ]
+        within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
         if not all(within_tolerance):
             outside_tolerance[key] = [
-                ts
-                for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
-                if not is_within
+                ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
             ]

     if len(outside_tolerance) > 0:

@@ -605,9 +575,7 @@ def check_delta_timestamps(
     return True


-def get_delta_indices(
-    delta_timestamps: dict[str, list[float]], fps: int
-) -> dict[str, list[int]]:
+def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
     delta_indices = {}
     for key, delta_ts in delta_timestamps.items():
         delta_indices[key] = [round(d * fps) for d in delta_ts]

@@ -672,9 +640,7 @@ def create_lerobot_dataset_card(
         ],
     )

-    card_template = (
-        importlib.resources.files("lerobot.common.datasets") / "card_template.md"
-    ).read_text()
+    card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()

     return DatasetCard.from_template(
         card_data=card_data,

@@ -743,18 +709,14 @@ def validate_frame(frame: dict, features: dict):
     expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
     actual_features = set(frame.keys())

-    error_message = validate_features_presence(
-        actual_features, expected_features, optional_features
-    )
+    error_message = validate_features_presence(actual_features, expected_features, optional_features)

     if "task" in frame:
         error_message += validate_feature_string("task", frame["task"])

     common_features = actual_features & (expected_features | optional_features)
     for name in common_features - {"task"}:
-        error_message += validate_feature_dtype_and_shape(
-            name, features[name], frame[name]
-        )
+        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])

     if error_message:
         raise ValueError(error_message)

@@ -777,9 +739,7 @@ def validate_features_presence(
     return error_message


-def validate_feature_dtype_and_shape(
-    name: str, feature: dict, value: np.ndarray | PILImage.Image | str
-):
+def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
     expected_dtype = feature["dtype"]
     expected_shape = feature["shape"]
     if is_valid_numpy_dtype_string(expected_dtype):

@@ -789,9 +749,7 @@ def validate_feature_dtype_and_shape(
     elif expected_dtype == "string":
         return validate_feature_string(name, value)
     else:
-        raise NotImplementedError(
-            f"The feature dtype '{expected_dtype}' is not implemented yet."
-        )
+        raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")


 def validate_feature_numpy_array(

@@ -813,17 +771,13 @@ def validate_feature_numpy_array(
     return error_message


-def validate_feature_image_or_video(
-    name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
-):
+def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
     # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
     error_message = ""
     if isinstance(value, np.ndarray):
         actual_shape = value.shape
         c, h, w = expected_shape
-        if len(actual_shape) != 3 or (
-            actual_shape != (c, h, w) and actual_shape != (h, w, c)
-        ):
+        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
             error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
     elif isinstance(value, PILImage.Image):
         pass

@@ -854,9 +808,7 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
         )

     if episode_buffer["size"] == 0:
-        raise ValueError(
-            "You must add one or several frames with `add_frame` before calling `add_episode`."
-        )
+        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")

     buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
     if not buffer_keys == set(features):
@@ -218,9 +218,7 @@ def get_features_from_hf_dataset(
             dtype = ft.feature.dtype
             shape = (ft.length,)
             motor_names = (
-                robot_config["names"][key]
-                if robot_config
-                else [f"motor_{i}" for i in range(ft.length)]
+                robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
             )
             assert len(motor_names) == shape[0]
             names = {"motors": motor_names}

@@ -244,15 +242,11 @@ def get_features_from_hf_dataset(
     return features


-def add_task_index_by_episodes(
-    dataset: Dataset, tasks_by_episodes: dict
-) -> tuple[Dataset, list[str]]:
+def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
     df = dataset.to_pandas()
     tasks = list(set(tasks_by_episodes.values()))
     tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
-    episodes_to_task_index = {
-        ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
-    }
+    episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
     df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)

     features = dataset.features

@@ -269,19 +263,10 @@ def add_task_index_from_tasks_col(
     # HACK: This is to clean some of the instructions in our version of Open X datasets
     prefix_to_clean = "tf.Tensor(b'"
     suffix_to_clean = "', shape=(), dtype=string)"
-    df[tasks_col] = (
-        df[tasks_col]
-        .str.removeprefix(prefix_to_clean)
-        .str.removesuffix(suffix_to_clean)
-    )
+    df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)

     # Create task_index col
-    tasks_by_episode = (
-        df.groupby("episode_index")[tasks_col]
-        .unique()
-        .apply(lambda x: x.tolist())
-        .to_dict()
-    )
+    tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
     tasks = df[tasks_col].unique().tolist()
     tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
     df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)

@@ -306,9 +291,7 @@ def split_parquet_by_episodes(
     for ep_chunk in range(total_chunks):
         ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
         ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
-        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
-            episode_chunk=ep_chunk
-        )
+        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
         (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
         for ep_idx in range(ep_chunk_start, ep_chunk_end):
             ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))

@@ -340,9 +323,7 @@ def move_videos(
     videos_moved = False
     video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
     if len(video_files) == 0:
-        video_files = [
-            str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
-        ]
+        video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
         videos_moved = True  # Videos have already been moved

     assert len(video_files) == total_episodes * len(video_keys)

@@ -373,9 +354,7 @@ def move_videos(
             target_path = DEFAULT_VIDEO_PATH.format(
                 episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
             )
-            video_file = V1_VIDEO_FILE.format(
-                video_key=vid_key, episode_index=ep_idx
-            )
+            video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
             if len(video_dirs) == 1:
                 video_path = video_dirs[0] / video_file
             else:

@@ -392,9 +371,7 @@ def move_videos(
     subprocess.run(["git", "push"], cwd=work_dir, check=True)


-def fix_lfs_video_files_tracking(
-    work_dir: Path, lfs_untracked_videos: list[str]
-) -> None:
+def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
     """
     HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
     there's no other option than to download the actual files and reupload them with lfs tracking.

@@ -418,14 +395,10 @@ def fix_lfs_video_files_tracking(
     subprocess.run(["git", "push"], cwd=work_dir, check=True)


-def fix_gitattributes(
-    work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
-) -> None:
+def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
     shutil.copyfile(clean_gittatributes, current_gittatributes)
     subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
-    subprocess.run(
-        ["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
-    )
+    subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
     subprocess.run(["git", "push"], cwd=work_dir, check=True)

@@ -462,9 +435,7 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
     return [f for f in video_files if f not in lfs_tracked_files]


-def get_videos_info(
-    repo_id: str, local_dir: Path, video_keys: list[str], branch: str
-) -> dict:
+def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
     # Assumes first episode
     video_files = [
         DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
@@ -539,31 +510,19 @@ def convert_dataset(
     if single_task:
         tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
         dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
-        tasks_by_episodes = {
-            ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
-        }
+        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
     elif tasks_path:
         tasks_by_episodes = load_json(tasks_path)
-        tasks_by_episodes = {
-            int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
-        }
+        tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
         dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
-        tasks_by_episodes = {
-            ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
-        }
+        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
     elif tasks_col:
-        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
-            dataset, tasks_col
-        )
+        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
     else:
         raise ValueError

-    assert set(tasks) == {
-        task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
-    }
-    tasks = [
-        {"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
-    ]
+    assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
+    tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
     write_jsonlines(tasks, v20_dir / TASKS_PATH)
     features["task_index"] = {
         "dtype": "int64",

@@ -593,9 +552,7 @@ def convert_dataset(
             clean_gitattr,
             branch,
         )
-        videos_info = get_videos_info(
-            repo_id, v1x_dir, video_keys=video_keys, branch=branch
-        )
+        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
         for key in video_keys:
             features[key]["shape"] = (
                 videos_info[key].pop("video.height"),

@@ -603,22 +560,15 @@ def convert_dataset(
                 videos_info[key].pop("video.channels"),
             )
             features[key]["video_info"] = videos_info[key]
-            assert math.isclose(
-                videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
-            )
+            assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
             if "encoding" in metadata_v1:
-                assert (
-                    videos_info[key]["video.pix_fmt"]
-                    == metadata_v1["encoding"]["pix_fmt"]
-                )
+                assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
     else:
         assert metadata_v1.get("video", 0) == 0
         videos_info = None

     # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(
-        dataset, total_episodes, total_chunks, v20_dir
-    )
+    episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)

     if robot_config is not None:
         robot_type = robot_config.type

@@ -656,14 +606,10 @@ def convert_dataset(
     }
     write_json(metadata_v2_0, v20_dir / INFO_PATH)
     convert_stats_to_json(v1x_dir, v20_dir)
-    card = create_lerobot_dataset_card(
-        tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
-    )
+    card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)

     with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
-        hub_api.delete_folder(
-            repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
-        )
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)

     with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
         hub_api.delete_folder(

@@ -674,9 +620,7 @@ def convert_dataset(
         )

     with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
-        hub_api.delete_folder(
-            repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
-        )
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)

     hub_api.upload_folder(
         repo_id=repo_id,
@@ -35,30 +35,22 @@ def fix_dataset(repo_id: str) -> str:

     dataset_info = get_dataset_config_info(repo_id, "default")
     with SuppressWarnings():
-        lerobot_metadata = LeRobotDatasetMetadata(
-            repo_id, revision=V20, force_cache_sync=True
-        )
+        lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)

-    meta_features = {
-        key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"
-    }
+    meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
     parquet_features = set(dataset_info.features)

     diff_parquet_meta = parquet_features - meta_features
     diff_meta_parquet = meta_features - parquet_features

     if diff_parquet_meta:
-        raise ValueError(
-            f"In parquet not in info.json: {parquet_features - meta_features}"
-        )
+        raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")

     if not diff_meta_parquet:
         return f"{repo_id}: skipped (no diff)"

     if diff_meta_parquet:
-        logging.warning(
-            f"In info.json not in parquet: {meta_features - parquet_features}"
-        )
+        logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
     assert diff_meta_parquet == {"language_instruction"}
     lerobot_metadata.features.pop("language_instruction")
     write_info(lerobot_metadata.info, lerobot_metadata.root)

@@ -99,9 +99,7 @@ def convert_dataset(
             repo_type="dataset",
         )

-    hub_api.create_tag(
-        repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
-    )
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")


 if __name__ == "__main__":
@@ -26,9 +26,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import write_episode_stats


-def sample_episode_video_frames(
-    dataset: LeRobotDataset, episode_index: int, ft_key: str
-) -> np.ndarray:
+def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
     ep_len = dataset.meta.episodes[episode_index]["length"]
     sampled_indices = sample_indices(ep_len)
     query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})

@@ -51,14 +49,11 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):

         axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
         keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
-        ep_stats[key] = get_feature_stats(
-            ep_ft_data, axis=axes_to_reduce, keepdims=keepdims
-        )
+        ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)

         if ft["dtype"] in ["image", "video"]:  # remove batch dim
             ep_stats[key] = {
-                k: v if k == "count" else np.squeeze(v, axis=0)
-                for k, v in ep_stats[key].items()
+                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
             }

     dataset.meta.episodes_stats[ep_idx] = ep_stats
@@ -65,9 +65,7 @@ def decode_video_frames(
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
@@ -346,9 +344,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
@@ -362,9 +358,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"])
|
||||
if audio_stream_info.get("bit_rate")
|
||||
else None,
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
@@ -386,9 +380,7 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
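# ---------------------------------------------------------------------------
# Aside: a minimal, self-contained sketch of the ffprobe pattern used by
# get_audio_info/get_video_info above. The helper name and exact flags are
# illustrative assumptions, not code from this commit; ffprobe must be on PATH.
# ---------------------------------------------------------------------------
import json
import subprocess


def probe_first_audio_stream(video_path: str) -> dict:
    cmd = [
        "ffprobe",
        "-v", "error",
        "-select_streams", "a:0",  # first audio stream only
        "-show_entries", "stream=codec_name,channels,bit_rate,sample_rate",
        "-of", "json",  # machine-readable output, as in the commands above
        video_path,
    ]
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
    streams = json.loads(result.stdout).get("streams", [])
    return streams[0] if streams else {}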


@@ -61,16 +61,10 @@ class AlohaEnv(EnvConfig):

    def __post_init__(self):
        if self.obs_type == "pixels":
            self.features["top"] = PolicyFeature(
                type=FeatureType.VISUAL, shape=(480, 640, 3)
            )
            self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
        elif self.obs_type == "pixels_agent_pos":
            self.features["agent_pos"] = PolicyFeature(
                type=FeatureType.STATE, shape=(14,)
            )
            self.features["pixels/top"] = PolicyFeature(
                type=FeatureType.VISUAL, shape=(480, 640, 3)
            )
            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
            self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))

    @property
    def gym_kwargs(self) -> dict:
@@ -108,13 +102,9 @@ class PushtEnv(EnvConfig):

    def __post_init__(self):
        if self.obs_type == "pixels_agent_pos":
            self.features["pixels"] = PolicyFeature(
                type=FeatureType.VISUAL, shape=(384, 384, 3)
            )
            self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
        elif self.obs_type == "environment_state_agent_pos":
            self.features["environment_state"] = PolicyFeature(
                type=FeatureType.ENV, shape=(16,)
            )
            self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))

    @property
    def gym_kwargs(self) -> dict:
@@ -153,9 +143,7 @@ class XarmEnv(EnvConfig):

    def __post_init__(self):
        if self.obs_type == "pixels_agent_pos":
            self.features["agent_pos"] = PolicyFeature(
                type=FeatureType.STATE, shape=(4,)
            )
            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))

    @property
    def gym_kwargs(self) -> dict:

@@ -32,9 +32,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
    raise ValueError(f"Policy type '{env_type}' is not available.")


def make_env(
    cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> gym.vector.VectorEnv | None:
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
    """Makes a gym vector environment according to the config.

    Args:
@@ -58,9 +56,7 @@ def make_env(
    try:
        importlib.import_module(package_name)
    except ModuleNotFoundError as e:
        print(
            f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
        )
        print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
        raise e

    gym_handle = f"{package_name}/{cfg.task}"
@@ -68,18 +64,13 @@ def make_env(
    # batched version of the env that returns an observation of shape (b, c)
    env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
    env = env_cls(
        [
            lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
            for _ in range(n_envs)
        ]
        [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
    )

    return env
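# ---------------------------------------------------------------------------
# Aside: hedged usage sketch of the vector-env construction above. The handle
# "gym_pusht/PushT-v0" is an assumed example id; any registered env works.
# ---------------------------------------------------------------------------
import gymnasium as gym

n_envs = 2
env = gym.vector.SyncVectorEnv(
    [lambda: gym.make("gym_pusht/PushT-v0", disable_env_checker=True) for _ in range(n_envs)]
)
obs, info = env.reset(seed=0)  # observations gain a leading batch dim of n_envs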


def make_maniskill_env(
    cfg: DictConfig, n_envs: int | None = None
) -> gym.vector.VectorEnv | None:
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
    """Make ManiSkill3 gym environment"""
    from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

@@ -96,9 +87,7 @@ def make_maniskill_env(
    # state should have the size of 25
    # env = ConvertToLeRobotEnv(env, n_envs)
    # env = PixelWrapper(cfg, env, n_envs)
    env._max_episode_steps = env.max_episode_steps = (
        50  # gym_utils.find_max_episode_steps_value(env)
    )
    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
    env.unwrapped.metadata["render_fps"] = 20

    return env
@@ -125,11 +114,7 @@ class PixelWrapper(gym.Wrapper):
    def _get_obs(self, obs):
        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
        self._frames.append(frame)
        return {
            "pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
                self.env.device
            )
        }
        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}

    def reset(self, seed):
        obs, info = self.env.reset()  # (seed=seed)
@@ -164,9 +149,7 @@ class ConvertToLeRobotEnv(gym.Wrapper):

        images = torch.concat(images, axis=-1)
        # flatten the rest of the data which should just be state data
        observation = common.flatten_state_dict(
            observation, use_torch=True, device=self.base_env.device
        )
        observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
        ret = dict()
        ret["state"] = observation
        ret["pixels"] = images

@@ -46,9 +46,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten

        # sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, (
            f"expect channel last images, but instead got {img.shape=}"
        )
        assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

        # sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -81,9 +79,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
    for key, ft in env_cfg.features.items():
        if ft.type is FeatureType.VISUAL:
            if len(ft.shape) != 3:
                raise ValueError(
                    f"Number of dimensions of {key} != 3 (shape={ft.shape})"
                )
                raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")

            shape = get_channel_first_image_shape(ft.shape)
            feature = PolicyFeature(type=ft.type, shape=shape)

@@ -34,13 +34,7 @@ def make_optimizer_and_scheduler(
    Returns:
        tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
    """
    params = (
        policy.get_optim_params()
        if cfg.use_policy_training_preset
        else policy.parameters()
    )
    params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
    optimizer = cfg.optimizer.build(params)
    lr_scheduler = (
        cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
    )
    lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
    return optimizer, lr_scheduler

@@ -102,9 +102,7 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
    write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)


def load_optimizer_state(
    optimizer: torch.optim.Optimizer, save_dir: Path
) -> torch.optim.Optimizer:
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    current_state_dict = optimizer.state_dict()
    flat_state = load_file(save_dir / OPTIMIZER_STATE)
    state = unflatten_dict(flat_state)

@@ -36,9 +36,7 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
        return self.get_choice_name(self.__class__)

    @abc.abstractmethod
    def build(
        self, optimizer: Optimizer, num_training_steps: int
    ) -> LRScheduler | None:
    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
        raise NotImplementedError


@@ -79,11 +77,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
            )
            return max(
                0.0,
                0.5
                * (
                    1.0
                    + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
                ),
                0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)),
            )

        return LambdaLR(optimizer, lr_lambda, -1)
@@ -111,9 +105,7 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):

        def cosine_decay_schedule(current_step):
            step = min(current_step, self.num_decay_steps)
            cosine_decay = 0.5 * (
                1 + math.cos(math.pi * step / self.num_decay_steps)
            )
            cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
            alpha = self.decay_lr / self.peak_lr
            decayed = (1 - alpha) * cosine_decay + alpha
            return decayed
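# ---------------------------------------------------------------------------
# Aside: the decay factor above, restated as a standalone function (a sketch;
# argument names are illustrative). It follows a cosine from 1.0 down to
# alpha = decay_lr / peak_lr, then stays flat.
# ---------------------------------------------------------------------------
import math


def cosine_decay_factor(step: int, num_decay_steps: int, peak_lr: float, decay_lr: float) -> float:
    step = min(step, num_decay_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * step / num_decay_steps))
    alpha = decay_lr / peak_lr
    return (1 - alpha) * cosine_decay + alpha


assert cosine_decay_factor(0, 1000, 1e-4, 1e-5) == 1.0  # starts at the peak lr
assert abs(cosine_decay_factor(1000, 1000, 1e-4, 1e-5) - 0.1) < 1e-12  # ends at decay_lr / peak_lr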
@@ -132,8 +124,6 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:


def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
    state_dict = deserialize_json_into_object(
        save_dir / SCHEDULER_STATE, scheduler.state_dict()
    )
    state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
    scheduler.load_state_dict(state_dict)
    return scheduler

@@ -171,9 +171,7 @@ class ACTConfig(PreTrainedConfig):

    def validate_features(self) -> None:
        if not self.image_features and not self.env_state_feature:
            raise ValueError(
                "You must provide at least one image or the environment state among the inputs."
            )
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

    @property
    def observation_delta_indices(self) -> None:

@@ -63,9 +63,7 @@ class ACTPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(
            config.input_features, config.normalization_mapping, dataset_stats
        )
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -76,9 +74,7 @@ class ACTPolicy(PreTrainedPolicy):
        self.model = ACT(config)

        if config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler = ACTTemporalEnsembler(
                config.temporal_ensemble_coeff, config.chunk_size
            )
            self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

        self.reset()

@@ -122,12 +118,8 @@ class ACTPolicy(PreTrainedPolicy):

        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(
                batch
            )  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = [
                batch[key] for key in self.config.image_features
            ]
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = [batch[key] for key in self.config.image_features]

        # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
        # we are ensembling over.
@@ -154,19 +146,14 @@ class ACTPolicy(PreTrainedPolicy):
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(
                batch
            )  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = [
                batch[key] for key in self.config.image_features
            ]
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = [batch[key] for key in self.config.image_features]

        batch = self.normalize_targets(batch)
        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

        l1_loss = (
            F.l1_loss(batch["action"], actions_hat, reduction="none")
            * ~batch["action_is_pad"].unsqueeze(-1)
            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
        ).mean()

        loss_dict = {"l1_loss": l1_loss.item()}
@@ -176,12 +163,7 @@ class ACTPolicy(PreTrainedPolicy):
            # KL-divergence per batch element, then take the mean over the batch.
            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
            mean_kld = (
                (
                    -0.5
                    * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
                )
                .sum(-1)
                .mean()
                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
            )
            loss_dict["kld_loss"] = mean_kld.item()
            loss = l1_loss + mean_kld * self.config.kl_weight
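# ---------------------------------------------------------------------------
# Aside: the mean_kld term above is the closed-form KL(N(mu, sigma^2) || N(0, I)),
# summed over the latent dimension and averaged over the batch. A sketch, with
# log_sigma_x2 denoting log(sigma^2):
# ---------------------------------------------------------------------------
import torch


def gaussian_kl(mu: torch.Tensor, log_sigma_x2: torch.Tensor) -> torch.Tensor:
    return (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()


# KL is zero exactly when the posterior matches the standard normal prior.
assert gaussian_kl(torch.zeros(4, 8), torch.zeros(4, 8)).item() == 0.0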
@@ -235,9 +217,7 @@ class ACTTemporalEnsembler:
        ```
        """
        self.chunk_size = chunk_size
        self.ensemble_weights = torch.exp(
            -temporal_ensemble_coeff * torch.arange(chunk_size)
        )
        self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
        self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
        self.reset()

@@ -253,9 +233,7 @@ class ACTTemporalEnsembler:
        time steps, and pop/return the next batch of actions in the sequence.
        """
        self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
            device=actions.device
        )
        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
        if self.ensembled_actions is None:
            # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
            # time step of the episode.
@@ -270,22 +248,12 @@ class ACTTemporalEnsembler:
        else:
            # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
            # the online update for those entries.
            self.ensembled_actions *= self.ensemble_weights_cumsum[
                self.ensembled_actions_count - 1
            ]
            self.ensembled_actions += (
                actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
            )
            self.ensembled_actions /= self.ensemble_weights_cumsum[
                self.ensembled_actions_count
            ]
            self.ensembled_actions_count = torch.clamp(
                self.ensembled_actions_count + 1, max=self.chunk_size
            )
            self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
            self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
            self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
            self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
            # The last action, which has no prior online average, needs to get concatenated onto the end.
            self.ensembled_actions = torch.cat(
                [self.ensembled_actions, actions[:, -1:]], dim=1
            )
            self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
            self.ensembled_actions_count = torch.cat(
                [
                    self.ensembled_actions_count,
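# ---------------------------------------------------------------------------
# Aside: a hedged 1-D sketch of the online update above. Each time-step slot
# keeps an exponentially weighted running average of every action chunk that
# has predicted it, with weight exp(-coeff * i) for the i-th prediction.
# Variable names here are illustrative, not the class's own.
# ---------------------------------------------------------------------------
import torch

coeff, chunk_size = 0.01, 4
w = torch.exp(-coeff * torch.arange(chunk_size))
w_cumsum = torch.cumsum(w, dim=0)

avg, count = None, None
for chunk in (torch.ones(chunk_size), 2 * torch.ones(chunk_size)):
    if avg is None:
        avg = chunk.clone()
        count = torch.ones(chunk_size, dtype=torch.long)
    else:
        avg, count = avg[1:], count[1:]  # the first slot was popped and executed
        avg = (avg * w_cumsum[count - 1] + chunk[:-1] * w[count]) / w_cumsum[count]
        count = torch.clamp(count + 1, max=chunk_size)
        avg = torch.cat([avg, chunk[-1:]])  # the newest slot has no prior average
        count = torch.cat([count, torch.ones(1, dtype=torch.long)])

# After two chunks, overlapping slots hold the weighted average of 1.0 and 2.0.
assert torch.allclose(avg[:-1], (w[0] * 1.0 + w[1] * 2.0) / (w[0] + w[1]) * torch.ones(3))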
@@ -356,9 +324,7 @@ class ACT(nn.Module):
                config.dim_model,
            )
            # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
            self.vae_encoder_latent_output_proj = nn.Linear(
                config.dim_model, config.latent_dim * 2
            )
            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
            # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
            # dimension.
            num_input_token_encoder = 1 + config.chunk_size
@@ -366,9 +332,7 @@ class ACT(nn.Module):
                num_input_token_encoder += 1
            self.register_buffer(
                "vae_encoder_pos_enc",
                create_sinusoidal_pos_embedding(
                    num_input_token_encoder, config.dim_model
                ).unsqueeze(0),
                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
            )

        # Backbone for image feature extraction.
@@ -385,9 +349,7 @@ class ACT(nn.Module):
            # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
            # feature map).
            # Note: The forward method of this returns a dict: {"feature_map": output}.
            self.backbone = IntermediateLayerGetter(
                backbone_model, return_layers={"layer4": "feature_map"}
            )
            self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})

        # Transformer (acts as VAE decoder when training with the variational objective).
        self.encoder = ACTEncoder(config)
@@ -416,18 +378,14 @@ class ACT(nn.Module):
            n_1d_tokens += 1
        self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
        if self.config.image_features:
            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
                config.dim_model // 2
            )
            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

        # Transformer decoder.
        # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
        self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)

        # Final action regression head on the output of the transformer's decoder.
        self.action_head = nn.Linear(
            config.dim_model, self.config.action_feature.shape[0]
        )
        self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])

        self._reset_parameters()

@@ -437,9 +395,7 @@ class ACT(nn.Module):
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self, batch: dict[str, Tensor]
    ) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
        """A forward pass through the Action Chunking Transformer (with optional VAE encoder).

        `batch` should have the following structure:
@@ -475,13 +431,9 @@ class ACT(nn.Module):
                self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
            )  # (B, 1, D)
            if self.config.robot_state_feature:
                robot_state_embed = self.vae_encoder_robot_state_input_proj(
                    batch["observation.state"]
                )
                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
            action_embed = self.vae_encoder_action_input_proj(
                batch["action"]
            )  # (B, S, D)
            action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)

            if self.config.robot_state_feature:
                vae_encoder_input = [
@@ -526,26 +478,20 @@ class ACT(nn.Module):
            # When not using the VAE encoder, we set the latent to be all zeros.
            mu = log_sigma_x2 = None
            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
            latent_sample = torch.zeros(
                [batch_size, self.config.latent_dim], dtype=torch.float32
            ).to(batch["observation.state"].device)
            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
                batch["observation.state"].device
            )

        # Prepare transformer encoder inputs.
        encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
        encoder_in_pos_embed = list(
            self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
        )
        encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
        # Robot state token.
        if self.config.robot_state_feature:
            encoder_in_tokens.append(
                self.encoder_robot_state_input_proj(batch["observation.state"])
            )
            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
        # Environment state token.
        if self.config.env_state_feature:
            encoder_in_tokens.append(
                self.encoder_env_state_input_proj(
                    batch["observation.environment_state"]
                )
                self.encoder_env_state_input_proj(batch["observation.environment_state"])
            )

        # Camera observation features and positional embeddings.
@@ -556,9 +502,7 @@ class ACT(nn.Module):
            # For a list of images, the H and W may vary but H*W is constant.
            for img in batch["observation.images"]:
                cam_features = self.backbone(img)["feature_map"]
                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
                    dtype=cam_features.dtype
                )
                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
                cam_features = self.encoder_img_feat_input_proj(cam_features)

                # Rearrange features to (sequence, batch, dim).
@@ -604,14 +548,8 @@ class ACTEncoder(nn.Module):
    def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
        super().__init__()
        self.is_vae_encoder = is_vae_encoder
        num_layers = (
            config.n_vae_encoder_layers
            if self.is_vae_encoder
            else config.n_encoder_layers
        )
        self.layers = nn.ModuleList(
            [ACTEncoderLayer(config) for _ in range(num_layers)]
        )
        num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
        self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()

    def forward(
@@ -629,9 +567,7 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module):
    def __init__(self, config: ACTConfig):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            config.dim_model, config.n_heads, dropout=config.dropout
        )
        self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)

        # Feed forward layers.
        self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -646,9 +582,7 @@ class ACTEncoderLayer(nn.Module):
        self.activation = get_activation_fn(config.feedforward_activation)
        self.pre_norm = config.pre_norm

    def forward(
        self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
    ) -> Tensor:
    def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
        skip = x
        if self.pre_norm:
            x = self.norm1(x)
@@ -673,9 +607,7 @@ class ACTDecoder(nn.Module):
    def __init__(self, config: ACTConfig):
        """Convenience module for running multiple decoder layers followed by normalization."""
        super().__init__()
        self.layers = nn.ModuleList(
            [ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
        )
        self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
        self.norm = nn.LayerNorm(config.dim_model)

    def forward(
@@ -700,12 +632,8 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module):
    def __init__(self, config: ACTConfig):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            config.dim_model, config.n_heads, dropout=config.dropout
        )
        self.multihead_attn = nn.MultiheadAttention(
            config.dim_model, config.n_heads, dropout=config.dropout
        )
        self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
        self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)

        # Feed forward layers.
        self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -746,9 +674,7 @@ class ACTDecoderLayer(nn.Module):
        if self.pre_norm:
            x = self.norm1(x)
        q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
        x = self.self_attn(q, k, value=x)[
            0
        ]  # select just the output, not the attention weights
        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
        x = skip + self.dropout1(x)
        if self.pre_norm:
            skip = x
@@ -785,14 +711,9 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
    """

    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / dimension)
            for hid_j in range(dimension)
        ]
        return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
    )
    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.from_numpy(sinusoid_table).float()
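# ---------------------------------------------------------------------------
# Aside: quick sanity check of the table built above (a sketch relying on the
# function as defined; the shapes are illustrative). Even columns are sines
# and odd columns cosines, so position 0 maps to (0, 1, 0, 1, ...).
# ---------------------------------------------------------------------------
import torch

table = create_sinusoidal_pos_embedding(num_positions=50, dimension=512)
assert table.shape == (50, 512)
assert torch.allclose(table[0, 0::2], torch.zeros(256))  # sin(0) == 0
assert torch.allclose(table[0, 1::2], torch.ones(256))   # cos(0) == 1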
@@ -837,9 +758,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
        x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi

        inverse_frequency = self._temperature ** (
            2
            * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
            / self.dimension
            2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
        )

        x_range = x_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, 1)
@@ -847,15 +766,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):

        # Note: this stack then flatten operation results in interleaved sine and cosine terms.
        # pos_embed_x and pos_embed_y are (1, H, W, C // 2).
        pos_embed_x = torch.stack(
            (x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
        ).flatten(3)
        pos_embed_y = torch.stack(
            (y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
        ).flatten(3)
        pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
            0, 3, 1, 2
        )  # (1, C, H, W)
        pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
        pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
        pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2)  # (1, C, H, W)

        return pos_embed


@@ -205,16 +205,11 @@ class DiffusionConfig(PreTrainedConfig):

    def validate_features(self) -> None:
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError(
                "You must provide at least one image or the environment state among the inputs."
            )
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if (
                    self.crop_shape[0] > image_ft.shape[1]
                    or self.crop_shape[1] > image_ft.shape[2]
                ):
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "

@@ -70,9 +70,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(
            config.input_features, config.normalization_mapping, dataset_stats
        )
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -99,9 +97,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        if self.config.image_features:
            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
        if self.config.env_state_feature:
            self._queues["observation.environment_state"] = deque(
                maxlen=self.config.n_obs_steps
            )
            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -127,9 +123,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        """
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(
                batch
            )  # shallow copy so that adding a key doesn't modify the original
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
@@ -138,11 +132,7 @@ class DiffusionPolicy(PreTrainedPolicy):

        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {
                k: torch.stack(list(self._queues[k]), dim=1)
                for k in batch
                if k in self._queues
            }
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.diffusion.generate_actions(batch)

            # TODO(rcadene): make above methods return output dictionary?
@@ -157,9 +147,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(
                batch
            )  # shallow copy so that adding a key doesn't modify the original
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
@@ -201,9 +189,7 @@ class DiffusionModel(nn.Module):
        if self.config.env_state_feature:
            global_cond_dim += self.config.env_state_feature.shape[0]

        self.unet = DiffusionConditionalUnet1d(
            config, global_cond_dim=global_cond_dim * config.n_obs_steps
        )
        self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)

        self.noise_scheduler = _make_noise_scheduler(
            config.noise_scheduler_type,
@@ -249,9 +235,7 @@ class DiffusionModel(nn.Module):
                global_cond=global_cond,
            )
            # Compute previous image: x_t -> x_t-1
            sample = self.noise_scheduler.step(
                model_output, t, sample, generator=generator
            ).prev_sample
            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample

        return sample

@@ -263,15 +247,11 @@ class DiffusionModel(nn.Module):
        if self.config.image_features:
            if self.config.use_separate_rgb_encoder_per_camera:
                # Combine batch and sequence dims while rearranging to make the camera index dimension first.
                images_per_camera = einops.rearrange(
                    batch["observation.images"], "b s n ... -> n (b s) ..."
                )
                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
                img_features_list = torch.cat(
                    [
                        encoder(images)
                        for encoder, images in zip(
                            self.rgb_encoder, images_per_camera, strict=True
                        )
                        for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
                    ]
                )
                # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
@@ -285,9 +265,7 @@ class DiffusionModel(nn.Module):
            else:
                # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
                img_features = self.rgb_encoder(
                    einops.rearrange(
                        batch["observation.images"], "b s n ... -> (b s n) ..."
                    )
                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
                )
                # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
                # feature dim (effectively concatenating the camera features).
@@ -381,9 +359,7 @@ class DiffusionModel(nn.Module):
        elif self.config.prediction_type == "sample":
            target = batch["action"]
        else:
            raise ValueError(
                f"Unsupported prediction type {self.config.prediction_type}"
            )
            raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

        loss = F.mse_loss(pred, target, reduction="none")

@@ -443,9 +419,7 @@ class SpatialSoftmax(nn.Module):

        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
        # and causes a small degradation in pc_success of pre-trained models.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
        )
        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
        # register as buffer so it's moved to the correct device.
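# ---------------------------------------------------------------------------
# Aside: what the pos_x/pos_y grids are for, as a hedged standalone sketch:
# spatial softmax reduces each feature map to the softmax-weighted expected
# (x, y) coordinate of its activations, yielding one keypoint per channel.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def spatial_softmax_keypoints(features: torch.Tensor, pos_xy: torch.Tensor) -> torch.Tensor:
    # features: (B, C, H, W); pos_xy: (H*W, 2) coordinates in [-1, 1]
    b, c, h, w = features.shape
    attention = F.softmax(features.reshape(b * c, h * w), dim=-1)  # (B*C, H*W)
    expected_xy = attention @ pos_xy  # (B*C, 2)
    return expected_xy.reshape(b, c, 2)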
@@ -487,9 +461,7 @@ class DiffusionRgbEncoder(nn.Module):
            # Always use center crop for eval
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            if config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(
                    config.crop_shape
                )
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
@@ -510,9 +482,7 @@ class DiffusionRgbEncoder(nn.Module):
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=x.num_features // 16, num_channels=x.num_features
                ),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )

        # Set up pooling and final layers.
@@ -523,15 +493,11 @@ class DiffusionRgbEncoder(nn.Module):

        # Note: we have a check in the config class to make sure all images have the same shape.
        images_shape = next(iter(config.image_features.values())).shape
        dummy_shape_h_w = (
            config.crop_shape if config.crop_shape is not None else images_shape[1:]
        )
        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
        feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

        self.pool = SpatialSoftmax(
            feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
        )
        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()
@@ -573,11 +539,7 @@ def _replace_submodules(
    if predicate(root_module):
        return func(root_module)

    replace_list = [
        k.split(".")
        for k, m in root_module.named_modules(remove_duplicate=True)
        if predicate(m)
    ]
    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    for *parents, k in replace_list:
        parent_module = root_module
        if len(parents) > 0:
@@ -592,9 +554,7 @@ def _replace_submodules(
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all BN are replaced
    assert not any(
        predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
    )
    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
    return root_module
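# ---------------------------------------------------------------------------
# Aside: hedged usage sketch of _replace_submodules, mirroring the call in
# DiffusionRgbEncoder above: swap every BatchNorm2d in a ResNet for GroupNorm.
# ---------------------------------------------------------------------------
import torchvision
from torch import nn

backbone = torchvision.models.resnet18(weights=None)
backbone = _replace_submodules(
    root_module=backbone,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())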


@@ -622,9 +582,7 @@ class DiffusionConv1dBlock(nn.Module):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(
                inp_channels, out_channels, kernel_size, padding=kernel_size // 2
            ),
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )
@@ -647,13 +605,9 @@ class DiffusionConditionalUnet1d(nn.Module):
        # Encoder for the diffusion timestep.
        self.diffusion_step_encoder = nn.Sequential(
            DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
            nn.Linear(
                config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
            ),
            nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
            nn.Mish(),
            nn.Linear(
                config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
            ),
            nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
        )

        # The FiLM conditioning dimension.
@@ -678,16 +632,10 @@ class DiffusionConditionalUnet1d(nn.Module):
            self.down_modules.append(
                nn.ModuleList(
                    [
                        DiffusionConditionalResidualBlock1d(
                            dim_in, dim_out, **common_res_block_kwargs
                        ),
                        DiffusionConditionalResidualBlock1d(
                            dim_out, dim_out, **common_res_block_kwargs
                        ),
                        DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
                        DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
                        # Downsample as long as it is not the last block.
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1)
                        if not is_last
                        else nn.Identity(),
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )
@@ -716,24 +664,16 @@ class DiffusionConditionalUnet1d(nn.Module):
                nn.ModuleList(
                    [
                        # dim_in * 2, because it takes the encoder's skip connection as well
                        DiffusionConditionalResidualBlock1d(
                            dim_in * 2, dim_out, **common_res_block_kwargs
                        ),
                        DiffusionConditionalResidualBlock1d(
                            dim_out, dim_out, **common_res_block_kwargs
                        ),
                        DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
                        DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
                        # Upsample as long as it is not the last block.
                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
                        if not is_last
                        else nn.Identity(),
                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            DiffusionConv1dBlock(
                config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
            ),
            DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
            nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
        )

@@ -801,23 +741,17 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
        self.use_film_scale_modulation = use_film_scale_modulation
        self.out_channels = out_channels

        self.conv1 = DiffusionConv1dBlock(
            in_channels, out_channels, kernel_size, n_groups=n_groups
        )
        self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)

        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
        cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))

        self.conv2 = DiffusionConv1dBlock(
            out_channels, out_channels, kernel_size, n_groups=n_groups
        )
        self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)

        # A final convolution for dimension matching the residual (if needed).
        self.residual_conv = (
            nn.Conv1d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
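# ---------------------------------------------------------------------------
# Aside: a hedged sketch of the FiLM-conditioned forward pass this block
# implies (the actual method body may differ in detail): the conditioning
# vector yields a per-channel scale and/or bias applied between the two
# convolutions, with a residual connection around the whole block.
# ---------------------------------------------------------------------------
def film_forward(self, x, cond):
    out = self.conv1(x)  # (B, out_channels, T)
    embed = self.cond_encoder(cond).unsqueeze(-1)  # (B, cond_channels, 1)
    if self.use_film_scale_modulation:
        scale, bias = embed.chunk(2, dim=1)  # each (B, out_channels, 1)
        out = scale * out + bias
    else:
        out = out + embed  # bias-only modulation
    out = self.conv2(out)
    return out + self.residual_conv(x)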

@@ -104,9 +104,7 @@ def make_policy(
        PreTrainedPolicy: _description_
    """
    if bool(ds_meta) == bool(env_cfg):
        raise ValueError(
            "Either one of a dataset metadata or a sim env must be provided."
        )
        raise ValueError("Either one of a dataset metadata or a sim env must be provided.")

    # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
    # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
@@ -136,12 +134,8 @@ def make_policy(
        )
        features = env_to_policy_features(env_cfg)

        cfg.output_features = {
            key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
        }
        cfg.input_features = {
            key: ft for key, ft in features.items() if key not in cfg.output_features
        }
        cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
        cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
    kwargs["config"] = cfg

    if cfg.pretrained_path:

@@ -7,9 +7,7 @@ from torch import Tensor, nn

from .configuration_classifier import ClassifierConfig

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


@@ -53,9 +51,7 @@ class Classifier(
        super().__init__()
        self.config = config
        # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
        encoder = AutoModel.from_pretrained(
            self.config.model_name, trust_remote_code=True
        )
        encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
        # Extract vision model if we're given a multimodal model
        if hasattr(encoder, "vision_model"):
            logging.info("Multimodal model detected - using vision encoder only")
@@ -81,9 +77,7 @@ class Classifier(
            self.feature_dim = self.encoder.fc.in_features
            self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
        elif hasattr(self.encoder.config, "hidden_sizes"):
            self.feature_dim = self.encoder.config.hidden_sizes[
                -1
            ]  # Last channel dimension
            self.feature_dim = self.encoder.config.hidden_sizes[-1]  # Last channel dimension
        else:
            raise ValueError("Unsupported CNN architecture")

@@ -103,9 +97,7 @@ class Classifier(
        if hasattr(self.encoder.config, "hidden_size"):
            input_dim = self.encoder.config.hidden_size
        else:
            raise ValueError(
                "Unsupported transformer architecture since hidden_size is not found"
            )
            raise ValueError("Unsupported transformer architecture since hidden_size is not found")

        self.classifier_head = nn.Sequential(
            nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
@@ -141,10 +133,7 @@ class Classifier(
            return features
        else:  # Transformer models
            outputs = self.encoder(processed)
            if (
                hasattr(outputs, "pooler_output")
                and outputs.pooler_output is not None
            ):
            if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                return outputs.pooler_output
            return outputs.last_hidden_state[:, 0, :]

@@ -160,9 +149,7 @@ class Classifier(
        else:
            probabilities = torch.softmax(logits, dim=-1)

        return ClassifierOutput(
            logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
        )
        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

    def predict_reward(self, x, threshold=0.6):
        if self.config.num_classes == 2:

@@ -82,43 +82,25 @@ def create_stats_buffers(
        if stats:
            if isinstance(stats[key]["mean"], np.ndarray):
                if norm_mode is NormalizationMode.MEAN_STD:
                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(
                        dtype=torch.float32
                    )
                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(
                        dtype=torch.float32
                    )
                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
                elif norm_mode is NormalizationMode.MIN_MAX:
                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(
                        dtype=torch.float32
                    )
                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(
                        dtype=torch.float32
                    )
                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
            elif isinstance(stats[key]["mean"], torch.Tensor):
                # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
                # tensors anywhere (for example, when we use the same stats for normalization and
                # unnormalization). See the logic here
                # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
                if norm_mode is NormalizationMode.MEAN_STD:
                    buffer["mean"].data = (
                        stats[key]["mean"].clone().to(dtype=torch.float32)
                    )
                    buffer["std"].data = (
                        stats[key]["std"].clone().to(dtype=torch.float32)
                    )
                    buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
                    buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
                elif norm_mode is NormalizationMode.MIN_MAX:
                    buffer["min"].data = (
                        stats[key]["min"].clone().to(dtype=torch.float32)
                    )
                    buffer["max"].data = (
                        stats[key]["max"].clone().to(dtype=torch.float32)
                    )
                    buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
                    buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
            else:
                type_ = type(stats[key]["mean"])
                raise ValueError(
                    f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead."
                )
                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")

        stats_buffers[key] = buffer
    return stats_buffers
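# ---------------------------------------------------------------------------
# Aside: the two normalization modes the buffers above feed, as standalone
# sketches (the epsilon handling and the [-1, 1] target range are assumptions
# for illustration): MEAN_STD standardizes, MIN_MAX rescales.
# ---------------------------------------------------------------------------
import torch


def normalize_mean_std(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    return (x - mean) / (std + 1e-8)


def normalize_min_max(x: torch.Tensor, min_: torch.Tensor, max_: torch.Tensor) -> torch.Tensor:
    x = (x - min_) / (max_ - min_ + 1e-8)  # first to [0, 1]
    return x * 2 - 1  # then to [-1, 1]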

@@ -44,9 +44,7 @@ def main():
    else:
        dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"

    ckpt_torch_dir = (
        Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
    )
    ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
    ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
    save_dir = Path(f"../openpi/data/{model_name}/save")

@@ -72,9 +70,7 @@ def main():
    # Create LeRobot batch from Jax
    batch = {}
    for cam_key, uint_chw_array in example["images"].items():
        batch[f"observation.images.{cam_key}"] = (
            torch.from_numpy(uint_chw_array) / 255.0
        )
        batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
    batch["observation.state"] = torch.from_numpy(example["state"])
    batch["action"] = torch.from_numpy(outputs["actions"])
    batch["task"] = example["prompt"]

@@ -54,9 +54,7 @@ def get_paligemma_config(precision: str):
        "projector_hidden_act": "gelu_fast",
        "vision_use_head": False,
    }
    final_config = PaliGemmaConfig(
        text_config=text_config, vision_config=vision_config, **config
    )
    final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
    return final_config



@@ -322,9 +322,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
    return {f"{prefix}{key}": value for key, value in d.items()}


def convert_pi0_checkpoint(
    checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str
):
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
    # Break down orbax ckpts - they are in OCDBT
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
    # process projection params
@@ -384,9 +382,7 @@ def convert_pi0_checkpoint(
    # gemma_config=gemma_config, paligemma_config=paligemma_config)
    pi0_model = PI0Policy(pi0_config)

    paligemma_params = update_keys_with_prefix(
        paligemma_params, "model.paligemma_with_expert."
    )
    paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
    gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
    projection_params = update_keys_with_prefix(projection_params, "model.")


@@ -193,9 +193,7 @@ def aloha_gripper_to_angular(value):

    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
            2 * horn_radius * linear_position
        )
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return safe_arcsin(value)
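# ---------------------------------------------------------------------------
# Aside: numeric sanity check of linear_to_radian (a sketch, with a clamped
# arcsin standing in for safe_arcsin). The fraction is the law-of-cosines
# term for the horn/linkage triangle, so a 3-4-5 right triangle gives 0.
# ---------------------------------------------------------------------------
import math


def linear_to_radian_check(linear_position, arm_length, horn_radius):
    value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
    return math.asin(max(-1.0, min(1.0, value)))


assert abs(linear_to_radian_check(4.0, 5.0, 3.0)) < 1e-12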

    # The constants are taken from the Interbotix code.
@@ -246,9 +244,7 @@ class PI0Policy(PreTrainedPolicy):
        super().__init__(config)
        config.validate_features()
        self.config = config
        self.normalize_inputs = Normalize(
            config.input_features, config.normalization_mapping, dataset_stats
        )
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -256,9 +252,7 @@ class PI0Policy(PreTrainedPolicy):
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224"
        )
        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FlowMatching(config)

        self.reset()
@@ -271,9 +265,7 @@ class PI0Policy(PreTrainedPolicy):
        return self.parameters()

    @torch.no_grad
    def select_action(
        self, batch: dict[str, Tensor], noise: Tensor | None = None
    ) -> Tensor:
    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -312,9 +304,7 @@ class PI0Policy(PreTrainedPolicy):
        self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    def forward(
        self, batch: dict[str, Tensor], noise=None, time=None
    ) -> tuple[Tensor, dict[str, Tensor]]:
    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
        """Do a full training forward pass to compute the loss"""
        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -330,9 +320,7 @@ class PI0Policy(PreTrainedPolicy):
        actions_is_pad = batch.get("action_is_pad")

        loss_dict = {}
        losses = self.model.forward(
            images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
        )
        losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
        loss_dict["losses_after_forward"] = losses.clone()

        if actions_is_pad is not None:
@@ -359,9 +347,7 @@ class PI0Policy(PreTrainedPolicy):
        img_masks = []

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [
            key for key in self.config.image_features if key not in batch
        ]
        missing_img_keys = [key for key in self.config.image_features if key not in batch]

        if len(present_img_keys) == 0:
            raise ValueError(
@@ -373,9 +359,7 @@ class PI0Policy(PreTrainedPolicy):
            img = batch[key]

            if self.config.resize_imgs_with_padding is not None:
                img = resize_with_pad(
                    img, *self.config.resize_imgs_with_padding, pad_value=0
                )
                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)

            # Normalize from range [0,1] to [-1,1] as expected by siglip
            img = img * 2.0 - 1.0
@@ -414,9 +398,7 @@ class PI0Policy(PreTrainedPolicy):
            return_tensors="pt",
        )
        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(
            device=device, dtype=torch.bool
        )
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks
|
||||
|
||||
@@ -435,9 +417,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(
|
||||
actions[:, :, motor_idx]
|
||||
)
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
@@ -446,9 +426,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
|
||||
actions[:, :, motor_idx]
|
||||
)
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
@@ -498,25 +476,15 @@ class PI0FlowMatching(nn.Module):
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
attention_implementation=self.config.attention_implementation,
|
||||
)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_with_export_config
|
||||
)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
|
||||
|
||||
# Projections are float32
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
||||
self.action_in_proj = nn.Linear(
|
||||
self.config.max_action_dim, self.config.proj_width
|
||||
)
|
||||
self.action_out_proj = nn.Linear(
|
||||
self.config.proj_width, self.config.max_action_dim
|
||||
)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
|
||||
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(
|
||||
self.config.proj_width * 2, self.config.proj_width
|
||||
)
|
||||
self.action_time_mlp_out = nn.Linear(
|
||||
self.config.proj_width, self.config.proj_width
|
||||
)
|
||||
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
|
||||
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
|
||||
|
||||
self.set_requires_grad()
|
||||
|
||||
@@ -560,9 +528,7 @@ class PI0FlowMatching(nn.Module):
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(
|
||||
img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
|
||||
)
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
@@ -637,9 +603,7 @@ class PI0FlowMatching(nn.Module):
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(
|
||||
bsize, action_time_dim, dtype=torch.bool, device=device
|
||||
)
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
@@ -677,9 +641,7 @@ class PI0FlowMatching(nn.Module):
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
||||
state, x_t, time
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
@@ -703,9 +665,7 @@ class PI0FlowMatching(nn.Module):
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None
|
||||
) -> Tensor:
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
@@ -763,16 +723,12 @@ class PI0FlowMatching(nn.Module):
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
||||
state, x_t, timestep
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
|
||||
batch_size, suffix_len, prefix_len
|
||||
)
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
|
||||
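For orientation between these two files: `sample_actions` and `denoise_step` above implement flow-matching inference by integrating a learned velocity field from noise toward an action chunk. A minimal sketch of that Euler loop, assuming an illustrative `model_velocity` callable and step count (not this file's exact API):

# Hedged sketch: Euler integration of a learned velocity field, as used by
# flow-matching samplers. `model_velocity` stands in for denoise_step; the
# names and step count are illustrative assumptions.
import torch


def euler_sample(model_velocity, x_t: torch.Tensor, num_steps: int = 10) -> torch.Tensor:
    dt = -1.0 / num_steps  # integrate time from 1 (noise) down to 0 (actions)
    time = torch.ones(x_t.shape[0], device=x_t.device)
    for _ in range(num_steps):
        v_t = model_velocity(x_t, time)  # predicted velocity at the current time
        x_t = x_t + dt * v_t  # one explicit Euler step
        time = time + dt
    return x_t
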
@@ -39,13 +39,9 @@ def apply_rope(x, positions, max_wavelength=10_000):
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
        d_half, dtype=torch.float32, device=device
    )
    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
        torch.float32
    )
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)

    radians = radians[..., None, :]

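The `apply_rope` reflow above is pure arithmetic: per-pair frequencies follow timescale = max_wavelength^(2i/d), and each position becomes an angle position/timescale. A small self-contained sketch of that computation (shapes and names are illustrative, not the repository's exact function):

# Hedged sketch of the rotary-embedding angle computation. The resulting
# radians feed sin/cos rotations of query/key pairs; shapes are illustrative.
import torch


def rope_angles(positions: torch.Tensor, dim: int, max_wavelength: float = 10_000.0) -> torch.Tensor:
    d_half = dim // 2
    freq_exponents = (2.0 / dim) * torch.arange(d_half, dtype=torch.float32)
    timescale = max_wavelength**freq_exponents  # (d_half,), geometric frequency ladder
    return positions[..., None].float() / timescale  # (..., d_half) angles in radians
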
@@ -178,9 +174,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
    def __init__(self, config: PaliGemmaWithExpertConfig):
        super().__init__(config=config)
        self.config = config
        self.paligemma = PaliGemmaForConditionalGeneration(
            config=config.paligemma_config
        )
        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
        self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
        # Remove unused embed_tokens
        self.gemma_expert.model.embed_tokens = None
@@ -297,9 +291,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
                # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
                # the max len, then we (for instance) double the cache size. This implementation already exists
                # in `transformers`. (molbap)
                key_states = torch.cat(
                    [past_key_values[layer_idx]["key_states"], key_states], dim=1
                )
                key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
                value_states = torch.cat(
                    [past_key_values[layer_idx]["value_states"], value_states],
                    dim=1,
@@ -392,9 +384,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
        value_states,
    ):
        num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
        num_key_value_heads = (
            self.config.paligemma_config.text_config.num_key_value_heads
        )
        num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
        num_key_value_groups = num_att_heads // num_key_value_heads

        # query_states: batch_size, sequence_length, num_att_head, head_dim
@@ -442,9 +432,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
        att_weights *= head_dim**-0.5
        big_neg = -2.3819763e38  # See gemma/modules.py

        masked_att_weights = torch.where(
            attention_mask[:, None, :, :], att_weights, big_neg
        )
        masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)

        probs = nn.functional.softmax(masked_att_weights, dim=-1)
        probs = probs.to(dtype=value_states.dtype)
@@ -456,8 +444,6 @@ class PaliGemmaWithExpertModel(PreTrainedModel):

        att_output = att_output.permute(0, 2, 1, 3)
        # we use -1 because sequence length can change
        att_output = att_output.reshape(
            batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
        )
        att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)

        return att_output

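The eager attention path reflowed above follows a standard pattern: repeat key/value heads for grouped-query attention, mask with a large negative constant before the softmax, and fold the heads back into the hidden dimension. A minimal sketch under assumed shapes (illustrative, not this class's exact signature):

# Hedged sketch of eager grouped-query attention with boolean masking.
import torch


def eager_attention(q, k, v, mask, num_kv_groups: int):
    # q: (B, H, T, D); k, v: (B, H_kv, T, D); mask: (B, T, T) boolean
    k = k.repeat_interleave(num_kv_groups, dim=1)  # expand kv heads to match q
    v = v.repeat_interleave(num_kv_groups, dim=1)
    att = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    att = att.masked_fill(~mask[:, None, :, :], torch.finfo(att.dtype).min)
    probs = torch.softmax(att, dim=-1)
    out = probs @ v  # (B, H, T, D)
    return out.transpose(1, 2).reshape(q.shape[0], -1, q.shape[1] * q.shape[-1])
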
@@ -71,9 +71,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
    def _save_pretrained(self, save_directory: Path) -> None:
        self.config._save_pretrained(save_directory)
        model_to_save = self.module if hasattr(self, "module") else self
        save_model_as_safetensor(
            model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)
        )
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))

    @classmethod
    def from_pretrained(
@@ -112,9 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            policy = cls._load_as_safetensor(
                instance, model_file, config.device, strict
            )
            policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
        else:
            try:
                model_file = hf_hub_download(
@@ -128,9 +124,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
                    token=token,
                    local_files_only=local_files_only,
                )
                policy = cls._load_as_safetensor(
                    instance, model_file, config.device, strict
                )
                policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
@@ -141,12 +135,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
        return policy

    @classmethod
    def _load_as_safetensor(
        cls, model: T, model_file: str, map_location: str, strict: bool
    ) -> T:
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        if packaging.version.parse(safetensors.__version__) < packaging.version.parse(
            "0.4.3"
        ):
        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
            load_model_as_safetensor(model, model_file, strict=strict)
            if map_location != "cpu":
                logging.warning(
@@ -157,9 +147,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
                )
                model.to(map_location)
        else:
            safetensors.torch.load_model(
                model, model_file, strict=strict, device=map_location
            )
            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
        return model

    # def generate_model_card(self, *args, **kwargs) -> ModelCard:

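The `_load_as_safetensor` reflow above preserves a version gate worth noting: safetensors releases before 0.4.3 cannot load weights directly onto a device, so the model is loaded on CPU and moved afterwards. A minimal sketch of that pattern, using only public safetensors APIs (the helper name is an illustrative assumption):

# Hedged sketch of version-gated safetensors loading; `load_weights` is an
# illustrative name, not the repository's method.
import packaging.version
import safetensors
import safetensors.torch


def load_weights(model, model_file: str, device: str, strict: bool = True):
    if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
        safetensors.torch.load_model(model, model_file, strict=strict)  # CPU load only
        if device != "cpu":
            model.to(device)  # manual move for older releases
    else:
        safetensors.torch.load_model(model, model_file, strict=strict, device=device)
    return model
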
@@ -48,9 +48,7 @@ class SACConfig:
            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
        }
    )
    output_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {"action": "min_max"}
    )
    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
    output_normalization_params: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "action": {"min": [-1, -1], "max": [1, 1]},

@@ -18,8 +18,8 @@
# TODO: (1) better device management

import math
from typing import Callable, Optional, Tuple, Union, Dict, List
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import einops
import numpy as np
@@ -124,17 +124,13 @@ class SACPolicy(

        self.actor = Policy(
            encoder=encoder_actor,
            network=MLP(
                input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
            ),
            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
            action_dim=config.output_shapes["action"][0],
            encoder_is_shared=config.shared_encoder,
            **config.policy_kwargs,
        )
        if config.target_entropy is None:
            config.target_entropy = (
                -np.prod(config.output_shapes["action"][0]) / 2
            )  # (-dim(A)/2)
            config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)

        # TODO (azouitine): Handle the case where the temperature parameter is fixed
        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -146,10 +142,11 @@ class SACPolicy(

    def _save_pretrained(self, save_directory):
        """Custom save method to handle TensorDict properly"""
        import os
        import json
        import os
        from dataclasses import asdict
        from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME

        from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
        from safetensors.torch import save_model

        save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
@@ -177,12 +174,14 @@ class SACPolicy(
        **model_kwargs,
    ) -> "SACPolicy":
        """Custom load method to handle loading SAC policy from saved files"""
        import os
        import json
        import os
        from pathlib import Path

        from huggingface_hub import hf_hub_download
        from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
        from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
        from safetensors.torch import load_model

        from lerobot.common.policies.sac.configuration_sac import SACConfig

        # Check if model_id is a local path or a hub model ID
@@ -302,14 +301,10 @@ class SACPolicy(
    ) -> Tensor:
        self.temperature = self.log_alpha.exp().item()
        with torch.no_grad():
            next_action_preds, next_log_probs, _ = self.actor(
                next_observations, next_observation_features
            )
            next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)

        # TODO: (maractingi, azouitine) This is too slow, we should find a way to do this in a more efficient way
        next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
            "action"
        ]
        next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]

        # 2- compute q targets
        q_targets = self.critic_forward(
@@ -353,21 +348,15 @@ class SACPolicy(
        ).sum()
        return critics_loss

    def compute_loss_temperature(
        self, observations, observation_features: Tensor | None = None
    ) -> Tensor:
    def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
        """Compute the temperature loss"""
        # calculate temperature loss
        with torch.no_grad():
            _, log_probs, _ = self.actor(observations, observation_features)
        temperature_loss = (
            -self.log_alpha.exp() * (log_probs + self.config.target_entropy)
        ).mean()
        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
        return temperature_loss

    def compute_loss_actor(
        self, observations, observation_features: Tensor | None = None
    ) -> Tensor:
    def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
        self.temperature = self.log_alpha.exp().item()

        actions_pi, log_probs, _ = self.actor(observations, observation_features)
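The reflowed `compute_loss_temperature` implements L(α) = E[−α(log π(a|s) + H̄)], so the gradient raises α when the policy's entropy drops below the target H̄. A small numeric sketch with made-up values:

# Hedged numeric sketch of the temperature loss; values are illustrative.
import torch

log_alpha = torch.tensor(0.0, requires_grad=True)
log_probs = torch.tensor([-0.5, -1.5, -2.5])  # per-sample log pi(a|s)
target_entropy = -2.0  # the -dim(A)/2 heuristic used above
temperature_loss = (-log_alpha.exp() * (log_probs + target_entropy)).mean()
temperature_loss.backward()  # gradient on log_alpha drives alpha up or down
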
@@ -408,11 +397,7 @@ class MLP(nn.Module):
        if dropout_rate is not None and dropout_rate > 0:
            layers.append(nn.Dropout(p=dropout_rate))
        layers.append(nn.LayerNorm(hidden_dims[0]))
        layers.append(
            activations
            if isinstance(activations, nn.Module)
            else getattr(nn, activations)()
        )
        layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())

        # Rest of the layers
        for i in range(1, len(hidden_dims)):
@@ -424,11 +409,7 @@ class MLP(nn.Module):
            layers.append(nn.LayerNorm(hidden_dims[i]))

            # If we're at the final layer and a final activation is specified, use it
            if (
                i + 1 == len(hidden_dims)
                and activate_final
                and final_activation is not None
            ):
            if i + 1 == len(hidden_dims) and activate_final and final_activation is not None:
                layers.append(
                    final_activation
                    if isinstance(final_activation, nn.Module)
@@ -436,9 +417,7 @@ class MLP(nn.Module):
                )
            else:
                layers.append(
                    activations
                    if isinstance(activations, nn.Module)
                    else getattr(nn, activations)()
                    activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
                )

        self.net = nn.Sequential(*layers)
@@ -639,15 +618,11 @@ class Policy(nn.Module):
        # Compute standard deviations
        if self.fixed_std is None:
            log_std = self.std_layer(outputs)
            assert not torch.isnan(log_std).any(), (
                "[ERROR] log_std became NaN after std_layer!"
            )
            assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"

            if self.use_tanh_squash:
                log_std = torch.tanh(log_std)
                log_std = self.log_std_min + 0.5 * (
                    self.log_std_max - self.log_std_min
                ) * (log_std + 1.0)
                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
            else:
                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        else:
@@ -660,9 +635,7 @@ class Policy(nn.Module):

        if self.use_tanh_squash:
            actions = torch.tanh(x_t)
            log_probs -= torch.log(
                (1 - actions.pow(2)) + 1e-6
            )  # Adjust log-probs for Tanh
            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # Adjust log-probs for Tanh
        else:
            actions = x_t  # No Tanh; raw Gaussian sample

@@ -709,9 +682,7 @@ class SACObservationEncoder(nn.Module):
            freeze_image_encoder(self.image_enc_layers)
        else:
            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
        self.all_image_keys = [
            k for k in config.input_shapes if k.startswith("observation.image")
        ]
        self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
@@ -738,9 +709,7 @@ class SACObservationEncoder(nn.Module):
            self.aggregation_size += config.latent_dim
            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())

        self.aggregation_layer = nn.Linear(
            in_features=self.aggregation_size, out_features=config.latent_dim
        )
        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
        self.parameters_to_optimize += list(self.aggregation_layer.parameters())

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
@@ -753,19 +722,13 @@ class SACObservationEncoder(nn.Module):
        obs_dict = self.input_normalization(obs_dict)
        # Batch all images along the batch dimension, then encode them.
        if len(self.all_image_keys) > 0:
            images_batched = torch.cat(
                [obs_dict[key] for key in self.all_image_keys], dim=0
            )
            images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
            images_batched = self.image_enc_layers(images_batched)
            embeddings_chunks = torch.chunk(
                images_batched, dim=0, chunks=len(self.all_image_keys)
            )
            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
            feat.extend(embeddings_chunks)

        if "observation.environment_state" in self.config.input_shapes:
            feat.append(
                self.env_state_enc_layers(obs_dict["observation.environment_state"])
            )
            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
        if "observation.state" in self.config.input_shapes:
            feat.append(self.state_enc_layers(obs_dict["observation.state"]))

@@ -833,9 +796,7 @@ class PretrainedImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.image_enc_layers, self.image_enc_out_shape = (
            self._load_pretrained_vision_encoder(config)
        )
        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
        self.image_enc_proj = nn.Sequential(
            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
            nn.LayerNorm(config.latent_dim),
@@ -846,21 +807,15 @@ class PretrainedImageEncoder(nn.Module):
        """Set up CNN encoder"""
        from transformers import AutoModel

        self.image_enc_layers = AutoModel.from_pretrained(
            config.vision_encoder_name, trust_remote_code=True
        )
        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
        # self.image_enc_layers.pooler = Identity()

        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
                -1
            ]  # Last channel dimension
            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
        elif hasattr(self.image_enc_layers, "fc"):
            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
        else:
            raise ValueError(
                "Unsupported vision encoder architecture, make sure you are using a CNN"
            )
            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
        return self.image_enc_layers, self.image_enc_out_shape

    def forward(self, x):
@@ -896,9 +851,7 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
        for key, value in inner_dict.items():
            converted_params[outer_key][key] = torch.tensor(value)
            if "image" in outer_key:
                converted_params[outer_key][key] = converted_params[outer_key][
                    key
                ].view(3, 1, 1)
                converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)

    return converted_params


@@ -183,13 +183,9 @@ class TDMPCConfig(PreTrainedConfig):
                "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
            )
            if not self.use_mpc:
                raise ValueError(
                    "If `n_action_steps > 1`, `use_mpc` must be set to `True`."
                )
                raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
            if self.n_action_steps > self.horizon:
                raise ValueError(
                    "`n_action_steps` must be less than or equal to `horizon`."
                )
                raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")

    def get_optimizer_preset(self) -> AdamConfig:
        return AdamConfig(lr=self.optimizer_lr)
@@ -209,9 +205,7 @@ class TDMPCConfig(PreTrainedConfig):
        if image_ft.shape[-2] != image_ft.shape[-1]:
            # TODO(alexander-soare): This limitation is solely because of code in the random shift
            # augmentation. It should be able to be removed.
            raise ValueError(
                f"Only square images are handled now. Got image shape {image_ft.shape}."
            )
            raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")

    @property
    def observation_delta_indices(self) -> list:

@@ -83,9 +83,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(
            config.input_features, config.normalization_mapping, dataset_stats
        )
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -110,9 +108,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        """
        self._queues = {
            "observation.state": deque(maxlen=1),
            "action": deque(
                maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
            ),
            "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
        }
        if self.config.image_features:
            self._queues["observation.image"] = deque(maxlen=1)
@@ -127,9 +123,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        """Select a single action given environment observations."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(
                batch
            )  # shallow copy so that adding a key doesn't modify the original
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.image"] = batch[next(iter(self.config.image_features))]

        self._queues = populate_queues(self._queues, batch)
@@ -232,47 +226,35 @@ class TDMPCPolicy(PreTrainedPolicy):
                self.config.action_feature.shape[0],
                device=std.device,
            )
            gaussian_actions = torch.clamp(
                mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
            )
            gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)

            # Compute elite actions.
            actions = torch.cat([gaussian_actions, pi_actions], dim=1)
            value = self.estimate_value(z, actions).nan_to_num_(0)
            elite_idxs = torch.topk(
                value, self.config.n_elites, dim=0
            ).indices  # (n_elites, batch)
            elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices  # (n_elites, batch)
            elite_value = value.take_along_dim(elite_idxs, dim=0)  # (n_elites, batch)
            # (horizon, n_elites, batch, action_dim)
            elite_actions = actions.take_along_dim(
                einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
            )
            elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)

            # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
            max_value = elite_value.max(0, keepdim=True)[0]  # (1, batch)
            # The weighting is a softmax over trajectory values. Note that this is not the same as the usage
            # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
            # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
            score = torch.exp(
                self.config.elite_weighting_temperature * (elite_value - max_value)
            )
            score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
            score /= score.sum(axis=0, keepdim=True)
            # (horizon, batch, action_dim)
            _mean = torch.sum(
                einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
            )
            _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
            _std = torch.sqrt(
                torch.sum(
                    einops.rearrange(score, "n b -> n b 1")
                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
                    ** 2,
                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
                    dim=1,
                )
            )
            # Update mean with an exponential moving average, and std with a direct replacement.
            mean = (
                self.config.gaussian_mean_momentum * mean
                + (1 - self.config.gaussian_mean_momentum) * _mean
                self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
            )
            std = _std.clamp_(self.config.min_std, self.config.max_std)

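The elite-weighting comments above give the update in closed form: s = Ω/ΣΩ, μ = Σ(s⋅Γ), σ = √Σ(s⋅(Γ−μ)²). A compact sketch of exactly that computation, with illustrative shapes (n elites, batch b, horizon h, action dim d):

# Hedged sketch of the softmax-weighted elite statistics used by the CEM/MPPI
# loop above; function and argument names are illustrative.
import torch


def weighted_elite_stats(elite_value, elite_actions, temperature: float):
    # elite_value: (n, b); elite_actions: (h, n, b, d)
    score = torch.exp(temperature * (elite_value - elite_value.max(0, keepdim=True)[0]))
    score = score / score.sum(0, keepdim=True)  # s = Omega / sum(Omega)
    s = score[None, :, :, None]  # broadcast to (1, n, b, 1)
    mean = (s * elite_actions).sum(dim=1)  # (h, b, d)
    std = torch.sqrt((s * (elite_actions - mean[:, None]) ** 2).sum(dim=1))
    return mean, std
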
@@ -281,9 +263,7 @@ class TDMPCPolicy(PreTrainedPolicy):

        # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
        # scores from the last iteration.
        actions = elite_actions[
            :, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
        ]
        actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]

        return actions

@@ -306,8 +286,7 @@ class TDMPCPolicy(PreTrainedPolicy):
            # of the FOWM paper.
            if self.config.uncertainty_regularizer_coeff > 0:
                regularization = -(
                    self.config.uncertainty_regularizer_coeff
                    * self.model.Qs(z, actions[t]).std(0)
                    self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
                )
            else:
                regularization = 0
@@ -328,9 +307,7 @@ class TDMPCPolicy(PreTrainedPolicy):
            G += (
                running_discount
                * torch.min(
                    terminal_values[
                        torch.randint(0, self.config.q_ensemble_size, size=(2,))
                    ],
                    terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))],
                    dim=0,
                )[0]
            )
@@ -338,11 +315,7 @@ class TDMPCPolicy(PreTrainedPolicy):
            G += running_discount * torch.min(terminal_values, dim=0)[0]
        # Finally, also regularize the terminal value.
        if self.config.uncertainty_regularizer_coeff > 0:
            G -= (
                running_discount
                * self.config.uncertainty_regularizer_coeff
                * terminal_values.std(0)
            )
            G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
        return G

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
@@ -354,9 +327,7 @@ class TDMPCPolicy(PreTrainedPolicy):

        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(
                batch
            )  # shallow copy so that adding a key doesn't modify the original
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.image"] = batch[next(iter(self.config.image_features))]
        batch = self.normalize_targets(batch)

@@ -388,29 +359,21 @@ class TDMPCPolicy(PreTrainedPolicy):
            current_observation[k] = observations[k][0]
            next_observations[k] = observations[k][1:]
        horizon, batch_size = next_observations[
            "observation.image"
            if self.config.image_features
            else "observation.environment_state"
            "observation.image" if self.config.image_features else "observation.environment_state"
        ].shape[:2]

        # Run latent rollout using the latent dynamics model and policy model.
        # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
        # gives us a next `z`.
        batch_size = batch["index"].shape[0]
        z_preds = torch.empty(
            horizon + 1, batch_size, self.config.latent_dim, device=device
        )
        z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
        z_preds[0] = self.model.encode(current_observation)
        reward_preds = torch.empty_like(reward, device=device)
        for t in range(horizon):
            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
                z_preds[t], action[t]
            )
            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])

        # Compute Q and V value predictions based on the latent rollout.
        q_preds_ensemble = self.model.Qs(
            z_preds[:-1], action
        )  # (ensemble, horizon, batch)
        q_preds_ensemble = self.model.Qs(z_preds[:-1], action)  # (ensemble, horizon, batch)
        v_preds = self.model.V(z_preds[:-1])
        info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})

@@ -424,14 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy):
            # actions (not actions estimated by π).
            # Note: Here we do not use self.model_target, but self.model. This is to follow the original code
            # and the FOWM paper.
            q_targets = reward + self.config.discount * self.model.V(
                self.model.encode(next_observations)
            )
            q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
            # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
            # are using them to compute loss for V.
            v_targets = self.model_target.Qs(
                z_preds[:-1].detach(), action, return_min=True
            )
            v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)

        # Compute losses.
        # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@@ -474,9 +433,7 @@ class TDMPCPolicy(PreTrainedPolicy):
            temporal_loss_coeffs
            * F.mse_loss(
                q_preds_ensemble,
                einops.repeat(
                    q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
                ),
                einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
                reduction="none",
            ).sum(0)  # sum over ensemble
            # `q_preds_ensemble` depends on the first observation and the actions.
@@ -514,14 +471,12 @@ class TDMPCPolicy(PreTrainedPolicy):
        z_preds = z_preds.detach()
        # Use stopgrad for the advantage calculation.
        with torch.no_grad():
            advantage = self.model_target.Qs(
                z_preds[:-1], action, return_min=True
            ) - self.model.V(z_preds[:-1])
            advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
                z_preds[:-1]
            )
            info["advantage"] = advantage[0]
            # (t, b)
            exp_advantage = torch.clamp(
                torch.exp(advantage * self.config.advantage_scaling), max=100.0
            )
            exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
        action_preds = self.model.pi(z_preds[:-1])  # (t, b, a)
        # Calculate the MSE between the actions and the action predictions.
        # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
@@ -575,9 +530,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
        # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
        # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
        update_ema_parameters(
            self.model_target, self.model, self.config.target_model_momentum
        )
        update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)


class TDMPCTOLD(nn.Module):
@@ -588,9 +541,7 @@ class TDMPCTOLD(nn.Module):
        self.config = config
        self._encoder = TDMPCObservationEncoder(config)
        self._dynamics = nn.Sequential(
            nn.Linear(
                config.latent_dim + config.action_feature.shape[0], config.mlp_dim
            ),
            nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
            nn.LayerNorm(config.mlp_dim),
            nn.Mish(),
            nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -601,9 +552,7 @@ class TDMPCTOLD(nn.Module):
            nn.Sigmoid(),
        )
        self._reward = nn.Sequential(
            nn.Linear(
                config.latent_dim + config.action_feature.shape[0], config.mlp_dim
            ),
            nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
            nn.LayerNorm(config.mlp_dim),
            nn.Mish(),
            nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -671,9 +620,7 @@ class TDMPCTOLD(nn.Module):
                    "Sanity check. The last linear layer needs 0 initialization on weights."
                )
                nn.init.zeros_(m[-1].weight)
                nn.init.zeros_(
                    m[-1].bias
                )  # this has already been done, but keep this line here for good measure
                nn.init.zeros_(m[-1].bias)  # this has already been done, but keep this line here for good measure

    def encode(self, obs: dict[str, Tensor]) -> Tensor:
        """Encodes an observation into its latent representation."""
@@ -812,9 +759,7 @@ class TDMPCObservationEncoder(nn.Module):

        if config.robot_state_feature:
            self.state_enc_layers = nn.Sequential(
                nn.Linear(
                    config.robot_state_feature.shape[0], config.state_encoder_hidden_dim
                ),
                nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
@@ -823,9 +768,7 @@ class TDMPCObservationEncoder(nn.Module):

        if config.env_state_feature:
            self.env_state_enc_layers = nn.Sequential(
                nn.Linear(
                    config.env_state_feature.shape[0], config.state_encoder_hidden_dim
                ),
                nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
@@ -898,10 +841,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
        assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
        if isinstance(p, dict):
            raise RuntimeError("Dict parameter not supported")
        if (
            isinstance(module, nn.modules.batchnorm._BatchNorm)
            or not p.requires_grad
        ):
        if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
            # Copy BatchNorm parameters, and non-trainable parameters directly.
            p_ema.copy_(p.to(dtype=p_ema.dtype).data)
        with torch.no_grad():
@@ -909,9 +849,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
            p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)


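`update_ema_parameters` above applies the usual exponential-moving-average rule p_ema ← α·p_ema + (1−α)·p, copying BatchNorm and frozen parameters directly instead. A minimal sketch of the core update (names illustrative; the real function also checks parameter names and rejects dict parameters):

# Hedged sketch of the EMA update rule; a simplified version of the logic above.
import torch
from torch import nn


@torch.no_grad()
def ema_update(ema_net: nn.Module, net: nn.Module, alpha: float):
    for p_ema, p in zip(ema_net.parameters(), net.parameters()):
        if not p.requires_grad:
            p_ema.copy_(p.data)  # copy frozen parameters directly
        else:
            p_ema.mul_(alpha).add_(p.data, alpha=1 - alpha)
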
def flatten_forward_unflatten(
    fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:

@@ -172,10 +172,7 @@ class VQBeTConfig(PreTrainedConfig):

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if (
                    self.crop_shape[0] > image_ft.shape[1]
                    or self.crop_shape[1] > image_ft.shape[2]
                ):
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "

@@ -64,9 +64,7 @@ class VQBeTPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(
            config.input_features, config.normalization_mapping, dataset_stats
        )
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -97,17 +95,11 @@ class VQBeTPolicy(PreTrainedPolicy):
        if self.config.sequentially_select:
            decay_params = (
                decay_params
                + list(
                    self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
                )
                + list(
                    self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
                )
                + list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
                + list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
            )
        else:
            decay_params = decay_params + list(
                self.vqbet.action_head.map_to_cbet_preds_bin.parameters()
            )
            decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())

        return [
            {
@@ -145,12 +137,8 @@ class VQBeTPolicy(PreTrainedPolicy):
        """

        batch = self.normalize_inputs(batch)
        batch = dict(
            batch
        )  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = torch.stack(
            [batch[key] for key in self.config.image_features], dim=-4
        )
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)

@@ -161,14 +149,8 @@ class VQBeTPolicy(PreTrainedPolicy):
        )

        if len(self._queues["action"]) == 0:
            batch = {
                k: torch.stack(list(self._queues[k]), dim=1)
                for k in batch
                if k in self._queues
            }
            actions = self.vqbet(batch, rollout=True)[
                :, : self.config.action_chunk_size
            ]
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]

            # the dimension of returned action is (batch_size, action_chunk_size, action_dim)
            actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -181,12 +163,8 @@ class VQBeTPolicy(PreTrainedPolicy):
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        batch = dict(
            batch
        )  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = torch.stack(
            [batch[key] for key in self.config.image_features], dim=-4
        )
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        batch = self.normalize_targets(batch)
        # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
        if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -194,9 +172,7 @@ class VQBeTPolicy(PreTrainedPolicy):
            # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
            # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
            loss, n_different_codes, n_different_combinations, recon_l1_error = (
                self.vqbet.action_head.discretize(
                    self.config.n_vqvae_training_steps, batch["action"]
                )
                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
            )
            return loss, {
                "n_different_codes": n_different_codes,
@@ -253,9 +229,7 @@ class SpatialSoftmax(nn.Module):

        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
        # and causes a small degradation in pc_success of pre-trained models.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
        )
        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
        # register as buffer so it's moved to the correct device.
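The meshgrid buffer above supports the usual spatial-softmax readout: softmax each feature map into an attention distribution, then take the expected (x, y) coordinate per channel. A minimal sketch of that readout (shapes illustrative, not this class's forward):

# Hedged sketch of a spatial-softmax keypoint readout over feature maps.
import torch


def spatial_softmax(features: torch.Tensor, pos_grid: torch.Tensor) -> torch.Tensor:
    # features: (B, C, H, W); pos_grid: (H*W, 2) coordinates in [-1, 1]
    b, c, h, w = features.shape
    attention = torch.softmax(features.reshape(b * c, h * w), dim=-1)
    keypoints = attention @ pos_grid  # (B*C, 2) expected coordinates per channel
    return keypoints.reshape(b, c, 2)
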
@@ -370,12 +344,7 @@ class VQBeTModel(nn.Module):
        num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
        self.register_buffer(
            "select_target_actions_indices",
            torch.row_stack(
                [
                    torch.arange(i, i + self.config.action_chunk_size)
                    for i in range(num_tokens)
                ]
            ),
            torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
        )

    def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
@@ -406,19 +375,13 @@ class VQBeTModel(nn.Module):
        input_tokens.append(
            self.state_projector(batch["observation.state"])
        )  # (batch, obs_step, projection dims)
        input_tokens.append(
            einops.repeat(
                self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
            )
        )
        input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
        # Interleave tokens by stacking and rearranging.
        input_tokens = torch.stack(input_tokens, dim=2)
        input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")

        len_additional_action_token = self.config.n_action_pred_token - 1
        future_action_tokens = self.action_token.repeat(
            batch_size, len_additional_action_token, 1
        )
        future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)

        # add additional action query tokens for predicting future action chunks
        input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
@@ -427,9 +390,9 @@ class VQBeTModel(nn.Module):
        features = self.policy(input_tokens)
        # len(self.config.input_features) is the number of different observation modes.
        # this line gets the index of action prompt tokens.
        historical_act_pred_index = np.arange(0, n_obs_steps) * (
            len(self.config.input_features) + 1
        ) + len(self.config.input_features)
        historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
            self.config.input_features
        )

        # only extract the output tokens at the position of action query:
        # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
@@ -449,15 +412,13 @@ class VQBeTModel(nn.Module):
        action_head_output = self.action_head(features)
        # if rollout, VQ-BeT doesn't calculate loss
        if rollout:
            return action_head_output["predicted_action"][
                :, n_obs_steps - 1, :
            ].reshape(batch_size, self.config.action_chunk_size, -1)
            return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
                batch_size, self.config.action_chunk_size, -1
            )
|
||||
else:
|
||||
output = batch["action"][:, self.select_target_actions_indices]
|
||||
loss = self.action_head.loss_fn(
|
||||
action_head_output, output, reduction="mean"
|
||||
)
|
||||
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
|
||||
return action_head_output, loss
|
||||
|
||||
|
||||
@@ -492,9 +453,7 @@ class VQBeTHead(nn.Module):
|
||||
else:
|
||||
self.map_to_cbet_preds_bin = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[
|
||||
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
|
||||
],
|
||||
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
|
||||
)
|
||||
self.map_to_cbet_preds_offset = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
@@ -521,10 +480,7 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
loss, metric = self.vqvae_model.vqvae_forward(actions)
|
||||
n_different_codes = sum(
|
||||
[
|
||||
len(torch.unique(metric[2][:, i]))
|
||||
for i in range(self.vqvae_model.vqvae_num_layers)
|
||||
]
|
||||
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
|
||||
)
|
||||
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
||||
recon_l1_error = metric[0].detach().cpu().item()
|
||||
@@ -585,18 +541,12 @@ class VQBeTHead(nn.Module):
|
||||
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
sampled_secondary_centers = einops.rearrange(
|
||||
torch.multinomial(
|
||||
cbet_secondary_probs.view(-1, choices), num_samples=1
|
||||
),
|
||||
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
|
||||
"(NT) 1 -> NT",
|
||||
NT=NT,
|
||||
)
|
||||
sampled_centers = torch.stack(
|
||||
(sampled_primary_centers, sampled_secondary_centers), axis=1
|
||||
)
|
||||
cbet_logits = torch.stack(
|
||||
[cbet_primary_logits, cbet_secondary_logits], dim=1
|
||||
)
|
||||
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
|
||||
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
|
||||
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
|
||||
else:
|
||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||
@@ -605,9 +555,7 @@ class VQBeTHead(nn.Module):
|
||||
"(NT) (G C) -> (NT) G C",
|
||||
G=self.vqvae_model.vqvae_num_layers,
|
||||
)
|
||||
cbet_probs = torch.softmax(
|
||||
cbet_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
sampled_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||
@@ -627,17 +575,9 @@ class VQBeTHead(nn.Module):
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
with torch.no_grad():
|
||||
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = (
|
||||
self.vqvae_model.get_embeddings_from_code(sampled_centers)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = (
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
|
||||
# reshaped extracted offset to match with decoded centroids
|
||||
sampled_offsets = einops.rearrange(
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
|
||||
@@ -686,9 +626,7 @@ class VQBeTHead(nn.Module):
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each ground truth action.
|
||||
with torch.no_grad():
|
||||
state_vq, action_bins = self.vqvae_model.get_code(
|
||||
action_seq
|
||||
) # action_bins: NT, G
|
||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
|
||||
@@ -711,12 +649,8 @@ class VQBeTHead(nn.Module):
|
||||
+ cbet_loss2 * self.config.secondary_code_loss_weight
|
||||
)
|
||||
|
||||
equal_primary_code_rate = torch.sum(
|
||||
(action_bins[:, 0] == sampled_centers[:, 0]).int()
|
||||
) / (NT)
|
||||
equal_secondary_code_rate = torch.sum(
|
||||
(action_bins[:, 1] == sampled_centers[:, 1]).int()
|
||||
) / (NT)
|
||||
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
|
||||
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
|
||||
|
||||
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
|
||||
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
|
||||
@@ -730,9 +664,7 @@ class VQBeTHead(nn.Module):
|
||||
"classification_loss": cbet_loss.detach().cpu().item(),
|
||||
"offset_loss": offset_loss.detach().cpu().item(),
|
||||
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
|
||||
.cpu()
|
||||
.item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
|
||||
"vq_action_error": vq_action_error.detach().cpu().item(),
|
||||
"offset_action_error": offset_action_error.detach().cpu().item(),
|
||||
"action_error_max": action_error_max.detach().cpu().item(),
|
||||
@@ -757,9 +689,7 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(
|
||||
config.crop_shape
|
||||
)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -780,9 +710,7 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
@@ -792,15 +720,11 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
)
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -842,11 +766,7 @@ def _replace_submodules(
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [
|
||||
k.split(".")
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
@@ -861,9 +781,7 @@ def _replace_submodules(
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(
|
||||
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
|
||||
)
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
|
||||
@@ -896,8 +814,7 @@ class VqVae(nn.Module):
        )

        self.encoder = MLP(
            in_channels=self.config.action_feature.shape[0]
            * self.config.action_chunk_size,
            in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
@@ -925,13 +842,9 @@ class VqVae(nn.Module):
        # given latent vector, this function outputs the decoded action.
        output = self.decoder(latent)
        if self.config.action_chunk_size == 1:
            return einops.rearrange(
                output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
            )
            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
        else:
            return einops.rearrange(
                output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
            )
            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])

    def get_code(self, state):
        # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
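Note that both branches of the `if` above perform the identical rearrange, so after this reformat the conditional is visibly redundant; either way the decoder output is simply unflattened. A self-contained example of the einops pattern (sizes are illustrative):

import einops
import torch

A, T = 6, 4                     # action dimension and chunk size
output = torch.randn(2, T * A)  # a batch of flattened action chunks
actions = einops.rearrange(output, "N (T A) -> N T A", A=A)
assert actions.shape == (2, T, A)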
@@ -123,15 +123,9 @@ class CausalSelfAttention(nn.Module):

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
        k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -139,9 +133,7 @@ class CausalSelfAttention(nn.Module):
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = (
            y.transpose(1, 2).contiguous().view(B, T, C)
        )  # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
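A compact sanity check of the shape bookkeeping spelled out in the comments above (sizes are illustrative; the real module splits the `c_attn` output into separate q, k and v):

import math

import torch

B, T, C, nh = 2, 5, 8, 2                       # batch, tokens, channels, heads
x = torch.randn(B, T, C)
k = x.view(B, T, nh, C // nh).transpose(1, 2)  # (B, nh, T, hs)
att = (k @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
assert att.shape == (B, nh, T, T)
y = (att @ k).transpose(1, 2).contiguous().view(B, T, C)  # back to (B, T, C)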
@@ -197,16 +189,12 @@ class GPT(nn.Module):
                "ln_f": nn.LayerNorm(config.gpt_hidden_dim),
            }
        )
        self.lm_head = nn.Linear(
            config.gpt_hidden_dim, config.gpt_output_dim, bias=False
        )
        self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
                )
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))

        # report number of parameters
        n_params = sum(p.numel() for p in self.parameters())
@@ -220,17 +208,11 @@ class GPT(nn.Module):
        )

        # positional encodings that are added to the input embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
            0
        )  # shape (1, t)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(
            input
        )  # token embeddings of shape (b, t, gpt_hidden_dim)
        pos_emb = self.transformer.wpe(
            pos
        )  # position embeddings of shape (1, t, gpt_hidden_dim)
        tok_emb = self.transformer.wte(input)  # token embeddings of shape (b, t, gpt_hidden_dim)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, gpt_hidden_dim)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
@@ -255,9 +237,7 @@ class GPT(nn.Module):
        # but want to use a smaller block size for some smaller, simpler model
        assert gpt_block_size <= self.config.gpt_block_size
        self.config.gpt_block_size = gpt_block_size
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:gpt_block_size]
        )
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]

@@ -290,10 +270,8 @@ class GPT(nn.Module):
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, (
            "parameters {} made it into both decay/no_decay sets!".format(
                str(inter_params)
            )
        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
            str(inter_params)
        )
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters {} were not separated into either decay/no_decay set!".format(
@@ -390,12 +368,8 @@ class ResidualVQ(nn.Module):
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
        )
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.num_quantizers = num_quantizers

@@ -477,9 +451,7 @@ class ResidualVQ(nn.Module):

        return all_codes

    def forward(
        self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
    ):
    def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
        """
        For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
        First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
@@ -508,17 +480,13 @@ class ResidualVQ(nn.Module):
        )
        ce_losses = []

        should_quantize_dropout = (
            self.training and self.quantize_dropout and not return_loss
        )
        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss

        if should_quantize_dropout:
            rand_quantize_dropout_index = randrange(
                self.quantize_dropout_cutoff_index, num_quant
            )
            rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
@@ -527,23 +495,14 @@ class ResidualVQ(nn.Module):
                    - 1
                )

            null_indices_shape = (
                (x.shape[0], *x.shape[-2:])
                if self.accept_image_fmap
                else tuple(x.shape[:2])
            )
            null_indices = torch.full(
                null_indices_shape, -1.0, device=device, dtype=torch.long
            )
            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)

        # go through the layers

        for quantizer_index, layer in enumerate(self.layers):
            if (
                should_quantize_dropout
                and quantizer_index > rand_quantize_dropout_index
            ):
            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue
@@ -583,9 +542,7 @@ class ResidualVQ(nn.Module):

        # stack all losses and indices

        all_losses, all_indices = map(
            partial(torch.stack, dim=-1), (all_losses, all_indices)
        )
        all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)
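The loop above is the heart of residual vector quantization: each layer quantizes whatever residual the previous layers left behind, and quantize-dropout randomly truncates the stack during training. A toy, self-contained sketch of the residual loop (a nearest-code lookup stands in for a full VectorQuantize layer):

import torch

def toy_quantize(residual: torch.Tensor, codebook: torch.Tensor):
    # Nearest-code lookup: the essence of a single quantizer layer.
    idx = torch.cdist(residual, codebook).argmin(dim=-1)
    return codebook[idx], idx

x = torch.randn(8, 16)
codebooks = [torch.randn(32, 16) for _ in range(4)]  # Nq = 4 layers
residual, quantized_out, all_indices = x, torch.zeros_like(x), []
for codebook in codebooks:
    quantized, idx = toy_quantize(residual, codebook)
    residual = residual - quantized           # the next layer sees what is left
    quantized_out = quantized_out + quantized
    all_indices.append(idx)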
@@ -645,12 +602,8 @@ class VectorQuantize(nn.Module):
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
        )
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.eps = eps
        self.commitment_weight = commitment_weight
@@ -664,14 +617,10 @@ class VectorQuantize(nn.Module):
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        assert not (ema_update and learnable_codebook), (
            "learnable codebook not compatible with EMA update"
        )
        assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"

        assert 0 <= sync_update_v <= 1.0
        assert not (sync_update_v > 0.0 and not learnable_codebook), (
            "learnable codebook must be turned on"
        )
        assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"

        self.sync_update_v = sync_update_v

@@ -683,9 +632,7 @@ class VectorQuantize(nn.Module):
        )

        if sync_codebook is None:
            sync_codebook = (
                distributed.is_initialized() and distributed.get_world_size() > 1
            )
            sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1

        codebook_kwargs = {
            "dim": codebook_dim,
@@ -850,17 +797,11 @@ class VectorQuantize(nn.Module):

        # quantize again

        quantize, embed_ind, distances = self._codebook(
            x, **codebook_forward_kwargs
        )
        quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

        if self.training:
            # determine code to use for commitment loss
            maybe_detach = (
                torch.detach
                if not self.learnable_codebook or freeze_codebook
                else identity
            )
            maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity

            commit_quantize = maybe_detach(quantize)

@@ -870,9 +811,7 @@ class VectorQuantize(nn.Module):

            if self.sync_update_v > 0.0:
                # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
                quantize = quantize + self.sync_update_v * (
                    quantize - quantize.detach()
                )
                quantize = quantize + self.sync_update_v * (quantize - quantize.detach())

        # function for calculating cross entropy loss to distance matrix
        # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@@ -905,9 +844,7 @@ class VectorQuantize(nn.Module):
            embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)

        if self.accept_image_fmap:
            embed_ind = rearrange(
                embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
            )
            embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)

        if only_one:
            embed_ind = rearrange(embed_ind, "b 1 -> b")
@@ -961,12 +898,8 @@ class VectorQuantize(nn.Module):

            num_codes = codebook.shape[-2]

            if (
                self.orthogonal_reg_max_codes is not None
            ) and num_codes > self.orthogonal_reg_max_codes:
                rand_ids = torch.randperm(num_codes, device=device)[
                    : self.orthogonal_reg_max_codes
                ]
            if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
                rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
                codebook = codebook[:, rand_ids]

            orthogonal_reg_loss = orthogonal_loss_fn(codebook)
@@ -998,9 +931,7 @@ class VectorQuantize(nn.Module):
        # if masking, only return quantized for where mask has True

        if mask is not None:
            quantize = torch.where(
                rearrange(mask, "... -> ... 1"), quantize, orig_input
            )
            quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)

        return quantize, embed_ind, loss

@@ -1110,9 +1041,7 @@ def sample_vectors(samples, num):


def batched_sample_vectors(samples, num):
    return torch.stack(
        [sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
    )
    return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)


def pad_shape(shape, size, dim=0):
@@ -1163,9 +1092,7 @@ def sample_vectors_distributed(local_samples, num):
    all_num_samples = all_gather_sizes(local_samples, dim=0)

    if rank == 0:
        samples_per_rank = sample_multinomial(
            num, all_num_samples / all_num_samples.sum()
        )
        samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
    else:
        samples_per_rank = torch.empty_like(all_num_samples)

@@ -1278,9 +1205,7 @@ class EuclideanCodebook(nn.Module):
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = (
            reset_cluster_size
            if (reset_cluster_size is not None)
            else threshold_ema_dead_code
            reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
        )

        assert callable(gumbel_sample)
@@ -1291,14 +1216,8 @@ class EuclideanCodebook(nn.Module):
            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
        )

        self.sample_fn = (
            sample_vectors_distributed
            if use_ddp and sync_kmeans
            else batched_sample_vectors
        )
        self.kmeans_all_reduce_fn = (
            distributed.all_reduce if use_ddp and sync_kmeans else noop
        )
        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@@ -1437,9 +1356,7 @@ class EuclideanCodebook(nn.Module):
            distributed.all_reduce(variance_number)
            batch_variance = variance_number / num_vectors

        self.update_with_decay(
            "batch_variance", batch_variance, self.affine_param_batch_decay
        )
        self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)

    def replace(self, batch_samples, batch_mask):
        for ind, (samples, mask) in enumerate(
@@ -1448,9 +1365,7 @@ class EuclideanCodebook(nn.Module):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(
                rearrange(samples, "... -> 1 ..."), mask.sum().item()
            )
            sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
            sampled = rearrange(sampled, "1 ... -> ...")

            self.embed.data[ind][mask] = sampled
@@ -1474,9 +1389,7 @@ class EuclideanCodebook(nn.Module):
    def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
        needs_codebook_dim = x.ndim < 4
        sample_codebook_temp = (
            sample_codebook_temp
            if (sample_codebook_temp is not None)
            else self.sample_codebook_temp
            sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
        )

        x = x.float()
@@ -1504,9 +1417,7 @@ class EuclideanCodebook(nn.Module):
        if self.affine_param:
            codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
            batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
            embed = (embed - self.codebook_mean) * (
                batch_std / codebook_std
            ) + self.batch_mean
            embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean

        dist = -cdist(flatten, embed)

@@ -1524,9 +1435,7 @@ class EuclideanCodebook(nn.Module):

        if self.training and self.ema_update and not freeze_codebook:
            if self.affine_param:
                flatten = (flatten - self.batch_mean) * (
                    codebook_std / batch_std
                ) + self.codebook_mean
                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean

            if mask is not None:
                embed_onehot[~mask] = 0.0
@@ -1549,9 +1458,7 @@ class EuclideanCodebook(nn.Module):
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = tuple(
                rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
            )
            quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))

        dist = unpack_one(dist, ps, "h * d")
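A toy sketch of the dead-code replacement pattern suggested by `replace` and `threshold_ema_dead_code` above: codes whose EMA cluster size fell below a threshold are overwritten with vectors sampled from the current batch. This is an inferred simplification for illustration, not the library's exact logic:

import torch

embed = torch.randn(64, 16)          # one codebook
cluster_size = torch.rand(64) * 4    # EMA usage statistics per code
dead = cluster_size < 2.0            # threshold_ema_dead_code analogue
batch = torch.randn(256, 16)
picks = torch.randint(0, batch.shape[0], (int(dead.sum()),))
embed[dead] = batch[picks]           # revive dead codes with fresh samples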
@@ -57,9 +57,7 @@ class OpenCVCameraConfig(CameraConfig):
        self.channels = 3

        if self.rotation not in [-90, None, 90, 180]:
            raise ValueError(
                f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
            )
            raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")


@CameraConfig.register_subclass("intelrealsense")
@@ -104,12 +102,8 @@ class IntelRealSenseCameraConfig(CameraConfig):

        self.channels = 3

        at_least_one_is_not_none = (
            self.fps is not None or self.width is not None or self.height is not None
        )
        at_least_one_is_none = (
            self.fps is None or self.width is None or self.height is None
        )
        at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
        at_least_one_is_none = self.fps is None or self.width is None or self.height is None
        if at_least_one_is_not_none and at_least_one_is_none:
            raise ValueError(
                "For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
@@ -117,6 +111,4 @@ class IntelRealSenseCameraConfig(CameraConfig):
            )

        if self.rotation not in [-90, None, 90, 180]:
            raise ValueError(
                f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
            )
            raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
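The two flags above encode an all-or-none rule for `fps`, `width` and `height`. A standalone restatement of the check, using a hypothetical helper that is not part of the repo:

def check_all_or_none(fps, width, height):
    some_set = fps is not None or width is not None or height is not None
    some_none = fps is None or width is None or height is None
    if some_set and some_none:
        raise ValueError("Set all of `fps`, `width` and `height`, or none of them.")

check_all_or_none(30, 640, 480)      # ok: all set
check_all_or_none(None, None, None)  # ok: none set
# check_all_or_none(30, None, None)  would raise ValueError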
@@ -79,9 +79,7 @@ def save_image(img_array, serial_number, frame_index, images_dir):
        img.save(str(path), quality=100)
        logging.info(f"Saved image: {path}")
    except Exception as e:
        logging.error(
            f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
        )
        logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")


def save_images_from_cameras(
@@ -159,9 +157,7 @@ def save_images_from_cameras(
                if time.perf_counter() - start_time > record_time_s:
                    break

                print(
                    f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
                )
                print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")

                frame_index += 1
    finally:
@@ -279,9 +275,7 @@ class IntelRealSenseCamera:
                f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
            )

        name_to_serial_dict = {
            cam["name"]: cam["serial_number"] for cam in camera_infos
        }
        name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
        cam_sn = name_to_serial_dict[name]

        return cam_sn
@@ -353,9 +347,7 @@ class IntelRealSenseCamera:
        actual_height = color_profile.height()

        # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
        if self.fps is not None and not math.isclose(
            self.fps, actual_fps, rel_tol=1e-3
        ):
        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
            # Using `OSError` since it's a broad error that encompasses issues related to device communication
            raise OSError(
                f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
@@ -375,9 +367,7 @@ class IntelRealSenseCamera:

        self.is_connected = True

    def read(
        self, temporary_color: str | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
        of type `np.uint8`, contrarily to the pytorch format which is float channel first.

@@ -404,15 +394,11 @@ class IntelRealSenseCamera:
        color_frame = frame.get_color_frame()

        if not color_frame:
            raise OSError(
                f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
            )
            raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")

        color_image = np.asanyarray(color_frame.get_data())

        requested_color_mode = (
            self.color_mode if temporary_color is None else temporary_color
        )
        requested_color_mode = self.color_mode if temporary_color is None else temporary_color
        if requested_color_mode not in ["rgb", "bgr"]:
            raise ValueError(
                f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
@@ -440,9 +426,7 @@ class IntelRealSenseCamera:
        if self.use_depth:
            depth_frame = frame.get_depth_frame()
            if not depth_frame:
                raise OSError(
                    f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
                )
                raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")

            depth_map = np.asanyarray(depth_frame.get_data())

@@ -484,9 +468,7 @@ class IntelRealSenseCamera:
            # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
            num_tries += 1
            time.sleep(1 / self.fps)
            if num_tries > self.fps and (
                self.thread.ident is None or not self.thread.is_alive()
            ):
            if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
                raise Exception(
                    "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
                )

@@ -45,14 +45,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60


def find_cameras(
    raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
) -> list[dict]:
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
    cameras = []
    if platform.system() == "Linux":
        print(
            "Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
        )
        print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
        possible_ports = [str(port) for port in Path("/dev").glob("video*")]
        ports = _find_cameras(possible_ports, mock=mock)
        for port in ports:
@@ -144,9 +140,7 @@ def save_images_from_cameras(
    print("Connecting cameras")
    cameras = []
    for cam_idx in camera_ids:
        config = OpenCVCameraConfig(
            camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock
        )
        config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
        camera = OpenCVCamera(config)
        camera.connect()
        print(
@@ -186,9 +180,7 @@ def save_images_from_cameras(
            dt_s = time.perf_counter() - now
            busy_wait(1 / fps - dt_s)

            print(
                f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
            )
            print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")

            if time.perf_counter() - start_time > record_time_s:
                break
@@ -245,16 +237,12 @@ class OpenCVCamera:
        if platform.system() == "Linux":
            if isinstance(self.camera_index, int):
                self.port = Path(f"/dev/video{self.camera_index}")
            elif isinstance(self.camera_index, str) and is_valid_unix_path(
                self.camera_index
            ):
            elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
                self.port = Path(self.camera_index)
                # Retrieve the camera index from a potentially symlinked path
                self.camera_index = get_camera_index_from_unix_port(self.port)
            else:
                raise ValueError(
                    f"Please check the provided camera_index: {self.camera_index}"
                )
                raise ValueError(f"Please check the provided camera_index: {self.camera_index}")

        # Store the raw (capture) resolution from the config.
        self.capture_width = config.width
@@ -295,9 +283,7 @@ class OpenCVCamera:

    def connect(self):
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                f"OpenCVCamera({self.camera_index}) is already connected."
            )
            raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")

        if self.mock:
            import tests.cameras.mock_cv2 as cv2
@@ -318,11 +304,7 @@ class OpenCVCamera:
            else cv2.CAP_ANY
        )

        camera_idx = (
            f"/dev/video{self.camera_index}"
            if platform.system() == "Linux"
            else self.camera_index
        )
        camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
        # First create a temporary camera trying to access `camera_index`,
        # and verify it is a valid camera by calling `isOpened`.
        tmp_camera = cv2.VideoCapture(camera_idx, backend)
@@ -362,9 +344,7 @@ class OpenCVCamera:
        actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)

        # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
        if self.fps is not None and not math.isclose(
            self.fps, actual_fps, rel_tol=1e-3
        ):
        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
            # Using `OSError` since it's a broad error that encompasses issues related to device communication
            raise OSError(
                f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
@@ -406,9 +386,7 @@ class OpenCVCamera:
        if not ret:
            raise OSError(f"Can't capture color image from camera {self.camera_index}.")

        requested_color_mode = (
            self.color_mode if temporary_color_mode is None else temporary_color_mode
        )
        requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode

        if requested_color_mode not in ["rgb", "bgr"]:
            raise ValueError(

@@ -93,9 +93,7 @@ class RecordControlConfig(ControlConfig):
        policy_path = parser.get_path_arg("control.policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("control.policy")
            self.policy = PreTrainedConfig.from_pretrained(
                policy_path, cli_overrides=cli_overrides
            )
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path


@@ -39,9 +39,7 @@ from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method


def log_control_info(
    robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
):
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
    log_items = []
    if episode_index is not None:
        log_items.append(f"ep:{episode_index}")
@@ -108,9 +106,7 @@ def predict_action(observation, policy, device, use_amp):
    observation = copy(observation)
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type)
        if device.type == "cuda" and use_amp
        else nullcontext(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
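The `with` statement above composes two context managers: inference mode always applies, while autocast is swapped for a `nullcontext` unless running on CUDA with AMP enabled, so the block's structure never changes. A runnable reduction of the pattern (device and flag values are illustrative):

from contextlib import nullcontext

import torch

device = torch.device("cpu")
use_amp = False
with (
    torch.inference_mode(),
    torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
    y = torch.ones(2, 3) @ torch.ones(3, 2)  # runs without gradient tracking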
@@ -166,9 +162,7 @@ def init_keyboard_listener(assign_rewards=False):
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
            elif key == keyboard.Key.left:
                print(
                    "Left arrow key pressed. Exiting loop and rerecord the last episode..."
                )
                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
@@ -262,9 +256,7 @@ def control_loop(
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        raise ValueError(
            f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
        )
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
@@ -302,9 +294,7 @@ def control_loop(
        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                cv2.imshow(
                    key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
                )
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

        if fps is not None:
@@ -392,14 +382,11 @@ def sanity_check_dataset_robot_compatibility(

    mismatches = []
    for field, dataset_value, present_value in fields:
        diff = DeepDiff(
            dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
        )
        diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
        if diff:
            mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")

    if mismatches:
        raise ValueError(
            "Dataset metadata compatibility check failed with mismatches:\n"
            + "\n".join(mismatches)
            "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
        )

@@ -161,9 +161,7 @@ NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10


def convert_degrees_to_steps(
    degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
    """This function converts the degree range to the step range for indicating motors rotation.
    It assumes a motor achieves a full rotation by going from -180 degree position to +180.
    The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
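A worked instance of the docstring's mapping, assuming the conversion is linear (which the docstring implies: a full -180 to +180 rotation spans `resolution` steps, so +180 degrees is half the resolution):

resolution = 4096  # steps per full rotation, e.g. for many Dynamixel models
degrees = 90
steps = degrees / 180 * (resolution // 2)
assert steps == 1024.0  # a quarter turn is a quarter of the resolution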
@@ -389,9 +387,7 @@ class DynamixelMotorsBus:
        indices = []
        for idx in tqdm.tqdm(possible_ids):
            try:
                present_idx = self.read_with_motor_ids(
                    self.motor_models, [idx], "ID", num_retry=num_retry
                )[0]
                present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
            except ConnectionError:
                continue

@@ -407,9 +403,7 @@ class DynamixelMotorsBus:
    def set_bus_baudrate(self, baudrate):
        present_bus_baudrate = self.port_handler.getBaudRate()
        if present_bus_baudrate != baudrate:
            print(
                f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
            )
            print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
            self.port_handler.setBaudRate(baudrate)

            if self.port_handler.getBaudRate() != baudrate:
@@ -430,9 +424,7 @@ class DynamixelMotorsBus:
    def set_calibration(self, calibration: dict[str, list]):
        self.calibration = calibration

    def apply_calibration_autocorrect(
        self, values: np.ndarray | list, motor_names: list[str] | None
    ):
    def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
        """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.

        For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -445,9 +437,7 @@ class DynamixelMotorsBus:
            values = self.apply_calibration(values, motor_names)
        return values

    def apply_calibration(
        self, values: np.ndarray | list, motor_names: list[str] | None
    ):
    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
        a "zero position" at 0 degree.

@@ -522,9 +512,7 @@ class DynamixelMotorsBus:

        return values

    def autocorrect_calibration(
        self, values: np.ndarray | list, motor_names: list[str] | None
    ):
    def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """This function automatically detects issues with values of motors after calibration, and corrects for these issues.

        Some motors might have values outside of expected maximum bounds after calibration.
@@ -566,23 +554,15 @@ class DynamixelMotorsBus:
                    values[i] *= -1

                # Convert from initial range to range [-180, 180] degrees
                calib_val = (
                    (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
                )
                in_range = (calib_val > LOWER_BOUND_DEGREE) and (
                    calib_val < UPPER_BOUND_DEGREE
                )
                calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
                in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)

                # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
                # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
                # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
                # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
                low_factor = (
                    -(resolution // 2) - values[i] - homing_offset
                ) / resolution
                upp_factor = (
                    (resolution // 2) - values[i] - homing_offset
                ) / resolution
                low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
                upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                start_pos = self.calibration["start_pos"][calib_idx]
@@ -590,9 +570,7 @@ class DynamixelMotorsBus:

                # Convert from initial range to range [0, 100] in %
                calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
                in_range = (calib_val > LOWER_BOUND_LINEAR) and (
                    calib_val < UPPER_BOUND_LINEAR
                )
                in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)

                # Solve this inequality to find the factor to shift the range into [0, 100] %
                # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -608,27 +586,19 @@ class DynamixelMotorsBus:
                    factor = math.ceil(low_factor)

                    if factor > upp_factor:
                        raise ValueError(
                            f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
                        )
                        raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
                else:
                    factor = math.ceil(upp_factor)

                    if factor > low_factor:
                        raise ValueError(
                            f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
                        )
                        raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")

                if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                    out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                    in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                    out_of_range_str = (
                        f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                    )
                    in_range_str = (
                        f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                    )
                    out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                    in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"

                logging.warning(
                    f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
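A worked numeric instance of the inequality solved above (DEGREE mode; the values are illustrative): with resolution 4096, homing offset 0 and a raw value of 5000, calib_val = 5000 / 2048 * 180, roughly 439 degrees, which is out of range, and shifting by one negative full turn repairs it:

import math

resolution, homing_offset, value = 4096, 0, 5000
low_factor = (-(resolution // 2) - value - homing_offset) / resolution  # ~ -1.72
upp_factor = ((resolution // 2) - value - homing_offset) / resolution   # ~ -0.72
factor = math.ceil(low_factor)                                          # -1
assert factor <= upp_factor              # an integer exists inside the interval
corrected = value + resolution * factor  # 904 steps, i.e. ~ 79.4 degrees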
@@ -638,9 +608,7 @@ class DynamixelMotorsBus:
                # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
                self.calibration["homing_offset"][calib_idx] += resolution * factor

    def revert_calibration(
        self, values: np.ndarray | list, motor_names: list[str] | None
    ):
    def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Inverse of `apply_calibration`."""
        if motor_names is None:
            motor_names = self.motor_names
@@ -679,9 +647,7 @@ class DynamixelMotorsBus:
        values = np.round(values).astype(np.int32)
        return values

    def read_with_motor_ids(
        self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
    ):
    def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
        if self.mock:
            import tests.motors.mock_dynamixel_sdk as dxl
        else:
@@ -783,9 +749,7 @@ class DynamixelMotorsBus:
            values = self.apply_calibration_autocorrect(values, motor_names)

        # log the number of seconds it took to read the data from the motors
        delta_ts_name = get_log_name(
            "delta_timestamp_s", "read", data_name, motor_names
        )
        delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
        self.logs[delta_ts_name] = time.perf_counter() - start_time

        # log the utc time at which the data was received
@@ -794,9 +758,7 @@ class DynamixelMotorsBus:

        return values

    def write_with_motor_ids(
        self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
    ):
    def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
        if self.mock:
            import tests.motors.mock_dynamixel_sdk as dxl
        else:
@@ -891,9 +853,7 @@ class DynamixelMotorsBus:
        )

        # log the number of seconds it took to write the data to the motors
        delta_ts_name = get_log_name(
            "delta_timestamp_s", "write", data_name, motor_names
        )
        delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
        self.logs[delta_ts_name] = time.perf_counter() - start_time

        # TODO(rcadene): should we log the time before sending the write command?

@@ -140,9 +140,7 @@ NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20


def convert_degrees_to_steps(
    degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
    """This function converts the degree range to the step range for indicating motors rotation.
    It assumes a motor achieves a full rotation by going from -180 degree position to +180.
    The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@@ -370,9 +368,7 @@ class FeetechMotorsBus:
        indices = []
        for idx in tqdm.tqdm(possible_ids):
            try:
                present_idx = self.read_with_motor_ids(
                    self.motor_models, [idx], "ID", num_retry=num_retry
                )[0]
                present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
            except ConnectionError:
                continue

@@ -388,9 +384,7 @@ class FeetechMotorsBus:
    def set_bus_baudrate(self, baudrate):
        present_bus_baudrate = self.port_handler.getBaudRate()
        if present_bus_baudrate != baudrate:
            print(
                f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
            )
            print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
            self.port_handler.setBaudRate(baudrate)

            if self.port_handler.getBaudRate() != baudrate:
@@ -411,9 +405,7 @@ class FeetechMotorsBus:
    def set_calibration(self, calibration: dict[str, list]):
        self.calibration = calibration

    def apply_calibration_autocorrect(
        self, values: np.ndarray | list, motor_names: list[str] | None
    ):
    def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
        """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.

        For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -426,9 +418,7 @@ class FeetechMotorsBus:
            values = self.apply_calibration(values, motor_names)
        return values

    def apply_calibration(
        self, values: np.ndarray | list, motor_names: list[str] | None
    ):
    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
        a "zero position" at 0 degree.

@@ -502,9 +492,7 @@ class FeetechMotorsBus:

        return values

    def autocorrect_calibration(
        self, values: np.ndarray | list, motor_names: list[str] | None
    ):
    def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """This function automatically detects issues with values of motors after calibration, and corrects for these issues.

        Some motors might have values outside of expected maximum bounds after calibration.
@@ -543,26 +531,18 @@ class FeetechMotorsBus:
                    values[i] *= -1

                # Convert from initial range to range [-180, 180] degrees
                calib_val = (
                    (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
                )
                in_range = (calib_val > LOWER_BOUND_DEGREE) and (
                    calib_val < UPPER_BOUND_DEGREE
                )
                calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
                in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)

                # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
                # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
                # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
                # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
                low_factor = (
                    -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
                    - values[i]
                    - homing_offset
                    -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
                ) / resolution
                upp_factor = (
                    HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
                    - values[i]
                    - homing_offset
                    HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
                ) / resolution

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
@@ -571,9 +551,7 @@ class FeetechMotorsBus:

                # Convert from initial range to range [0, 100] in %
                calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
                in_range = (calib_val > LOWER_BOUND_LINEAR) and (
                    calib_val < UPPER_BOUND_LINEAR
                )
                in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)

                # Solve this inequality to find the factor to shift the range into [0, 100] %
                # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -589,27 +567,19 @@ class FeetechMotorsBus:
                    factor = math.ceil(low_factor)

                    if factor > upp_factor:
                        raise ValueError(
                            f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
                        )
                        raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
                else:
                    factor = math.ceil(upp_factor)

                    if factor > low_factor:
                        raise ValueError(
                            f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
                        )
                        raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")

                if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                    out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                    in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                    out_of_range_str = (
                        f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                    )
                    in_range_str = (
                        f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                    )
                    out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                    in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"

                logging.warning(
                    f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -619,9 +589,7 @@ class FeetechMotorsBus:
                # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
                self.calibration["homing_offset"][calib_idx] += resolution * factor

    def revert_calibration(
        self, values: np.ndarray | list, motor_names: list[str] | None
    ):
    def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Inverse of `apply_calibration`."""
        if motor_names is None:
            motor_names = self.motor_names
@@ -697,9 +665,7 @@ class FeetechMotorsBus:

        return values

    def read_with_motor_ids(
        self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
    ):
    def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
        if self.mock:
            import tests.motors.mock_scservo_sdk as scs
        else:
@@ -808,9 +774,7 @@ class FeetechMotorsBus:
            values = self.apply_calibration_autocorrect(values, motor_names)

        # log the number of seconds it took to read the data from the motors
        delta_ts_name = get_log_name(
            "delta_timestamp_s", "read", data_name, motor_names
        )
        delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
        self.logs[delta_ts_name] = time.perf_counter() - start_time

        # log the utc time at which the data was received
@@ -819,9 +783,7 @@ class FeetechMotorsBus:

        return values

    def write_with_motor_ids(
        self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
    ):
    def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
        if self.mock:
            import tests.motors.mock_scservo_sdk as scs
        else:
@@ -916,9 +878,7 @@ class FeetechMotorsBus:
        )

        # log the number of seconds it took to write the data to the motors
        delta_ts_name = get_log_name(
            "delta_timestamp_s", "write", data_name, motor_names
        )
        delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
        self.logs[delta_ts_name] = time.perf_counter() - start_time

        # TODO(rcadene): should we log the time before sending the write command?

@@ -69,13 +69,9 @@ class ManipulatorRobotConfig(RobotConfig):
                if not cam.mock:
                    cam.mock = True

        if self.max_relative_target is not None and isinstance(
            self.max_relative_target, Sequence
        ):
        if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
            for name in self.follower_arms:
                if len(self.follower_arms[name].motors) != len(
                    self.max_relative_target
                ):
                if len(self.follower_arms[name].motors) != len(self.max_relative_target):
                    raise ValueError(
                        f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
                        f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "

@@ -24,7 +24,9 @@ from lerobot.common.robot_devices.motors.dynamixel import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus

URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
URL_TEMPLATE = (
    "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)

# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@@ -35,9 +37,7 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
    # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
    if not np.all(np.isin(drive_mode, [0, 1])):
        raise ValueError(
            f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
        )
        raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")


def apply_drive_mode(position, drive_mode):
@@ -78,16 +78,12 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
    ```
    """
    if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
        raise ValueError(
            "To run calibration, the torque must be disabled on all motors."
        )
        raise ValueError("To run calibration, the torque must be disabled on all motors.")

    print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")

    print("\nMove arm to zero position")
    print(
        "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
    )
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
    input("Press Enter to continue...")

    # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@@ -108,15 +104,10 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
    # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
    # of the previous motor in the kinetic chain.
    print("\nMove arm to rotated target position")
    print(
        "See: "
        + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
    )
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
    input("Press Enter to continue...")

    rotated_target_pos = convert_degrees_to_steps(
        ROTATED_POSITION_DEGREE, arm.motor_models
    )
    rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)

    # Find drive mode by rotating each motor by a quarter of a turn.
    # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -125,15 +116,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type

    # Re-compute homing offset to take into account drive mode
    rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
    rotated_nearest_pos = compute_nearest_rounded_position(
        rotated_drived_pos, arm.motor_models
    )
    rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
    homing_offset = rotated_target_pos - rotated_nearest_pos

    print("\nMove arm to rest position")
    print(
        "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
    )
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
    input("Press Enter to continue...")
    print()


@@ -26,7 +26,9 @@ from lerobot.common.robot_devices.motors.feetech import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus

URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
URL_TEMPLATE = (
    "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)

# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@@ -35,9 +37,7 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
    # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
    if not np.all(np.isin(drive_mode, [0, 1])):
        raise ValueError(
            f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
        )
        raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")


def apply_drive_mode(position, drive_mode):
@@ -140,9 +140,7 @@ def apply_offset(calib, offset):
    return calib


def run_arm_auto_calibration(
    arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
    if robot_type == "so100":
        return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
    elif robot_type == "moss":
@@ -151,27 +149,18 @@ def run_arm_auto_calibration(
        raise ValueError(robot_type)


def run_arm_auto_calibration_so100(
    arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
    """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
    if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
        raise ValueError(
            "To run calibration, the torque must be disabled on all motors."
        )
        raise ValueError("To run calibration, the torque must be disabled on all motors.")

    if not (robot_type == "so100" and arm_type == "follower"):
        raise NotImplementedError(
            "Auto calibration only supports the follower of so100 arms for now."
        )
        raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")

    print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")

    print("\nMove arm to initial position")
    print(
        "See: "
        + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
    )
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
    input("Press Enter to continue...")

    # Lower the acceleration of the motors (in [0,254])
@@ -225,9 +214,7 @@ def run_arm_auto_calibration_so100(
    )
    calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)

    arm.write(
        "Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
    )
    arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
    time.sleep(1)

    def in_between_move_hook():
@@ -261,13 +248,9 @@ def run_arm_auto_calibration_so100(
        "shoulder_lift",
    )
    time.sleep(2)
    arm.write(
        "Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
    )
    arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
    time.sleep(2)
    arm.write(
        "Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
    )
    arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
    time.sleep(2)
    arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
    time.sleep(2)
@@ -288,9 +271,7 @@ def run_arm_auto_calibration_so100(
    arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
    time.sleep(1)
    arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
    arm.write(
        "Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
    )
    arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
    time.sleep(1)
    arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
    time.sleep(1)
@@ -319,27 +300,18 @@ def run_arm_auto_calibration_so100(
    return calib_dict


def run_arm_auto_calibration_moss(
    arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
    """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
    if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
        raise ValueError(
            "To run calibration, the torque must be disabled on all motors."
        )
        raise ValueError("To run calibration, the torque must be disabled on all motors.")

    if not (robot_type == "moss" and arm_type == "follower"):
|
||||
raise NotImplementedError(
|
||||
"Auto calibration only supports the follower of moss arms for now."
|
||||
)
|
||||
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
@@ -423,12 +395,8 @@ def run_arm_auto_calibration_moss(
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write(
|
||||
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
|
||||
)
|
||||
arm.write(
|
||||
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
|
||||
)
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
|
||||
time.sleep(2)
|
||||
|
||||
calib_modes = []
|
||||
@@ -455,9 +423,7 @@ def run_arm_auto_calibration_moss(
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_manual_calibration(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""This function ensures that a neural network trained on data collected on a given robot
|
||||
can work on another robot. For instance before calibration, setting a same goal position
|
||||
for each motor of two different robots will get two very different positions. But after calibration,
|
||||
@@ -480,16 +446,12 @@ def run_arm_manual_calibration(
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
@@ -509,15 +471,10 @@ def run_arm_manual_calibration(
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(
|
||||
ROTATED_POSITION_DEGREE, arm.motor_models
|
||||
)
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -529,9 +486,7 @@ def run_arm_manual_calibration(
|
||||
homing_offset = rotated_target_pos - rotated_drived_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
|
||||
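For illustration, a minimal standalone sketch (not part of this commit) of how drive mode feeds into the homing offset computed in the calibration hunks above. The body of apply_drive_mode here is one plausible implementation consistent with the comments, and the step values are made up.

import numpy as np

def apply_drive_mode(position: np.ndarray, drive_mode: np.ndarray) -> np.ndarray:
    # drive_mode == 1 means the motor turns in the inverted direction, so flip its sign.
    return np.where(drive_mode == 1, -position, position)

rotated_pos = np.array([1024, -980])         # raw readings after the 90-degree move
drive_mode = np.array([0, 1])                # second motor is mounted mirrored
rotated_target_pos = np.array([1024, 1024])  # expected steps for +90 degrees

homing_offset = rotated_target_pos - apply_drive_mode(rotated_pos, drive_mode)
# homing_offset is then applied to every subsequent raw read so both arms agree.
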
@@ -42,9 +42,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
local_dict = {}
for name, cam in cameras.items():
frame = cam.async_read()
ret, buffer = cv2.imencode(
".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]
)
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
if ret:
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
else:
@@ -76,9 +74,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
print(f"[INFO] Loaded calibration from {calib_file}")
else:
print("[INFO] Calibration file not found. Running manual calibration...")
calibration = run_arm_manual_calibration(
motors_bus, "lekiwi", "follower_arm", "follower"
)
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
print(f"[INFO] Calibration complete. Saving to {calib_file}")
with open(calib_file, "w") as f:
json.dump(calibration, f)
@@ -174,9 +170,7 @@ def run_lekiwi(robot_config):
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
)
else:
for motor, pos in zip(
arm_motor_ids, arm_positions, strict=False
):
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
motors_bus.write("Goal_Position", pos, motor)
# Process wheel (base) commands.
if "raw_velocity" in data:
@@ -207,9 +201,7 @@ def run_lekiwi(robot_config):
try:
pos = motors_bus.read("Present_Position", motor)
# Convert the position to a float (or use as is if already numeric).
follower_arm_state.append(
float(pos) if not isinstance(pos, (int, float)) else pos
)
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
except Exception as e:
print(f"[ERROR] Reading motor {motor} failed: {e}")

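The encode-and-ship pattern in run_camera_capture above can be exercised on its own. This sketch (synthetic frame, not from the commit) shows the JPEG/base64 round trip that lets a frame travel inside a text message:

import base64

import cv2
import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for cam.async_read()
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
if ret:
    payload = base64.b64encode(buffer).decode("utf-8")
    # Receiving side: decode base64, then decompress the JPEG back into an array.
    raw = np.frombuffer(base64.b64decode(payload), dtype=np.uint8)
    img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
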
@@ -285,9 +285,7 @@ class ManipulatorRobot:
# to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")

# Check both arms can be read
for name in self.follower_arms:
@@ -323,22 +321,16 @@ class ManipulatorRobot:
run_arm_calibration,
)

calibration = run_arm_calibration(
arm, self.robot_type, name, arm_type
)
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)

elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)

calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)

print(
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
@@ -357,17 +349,13 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode

if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run set robot preset, the torque must be disabled on all motors."
)
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")

# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name for name in arm.motor_names if name != "gripper"
]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Koch motors
arm.write("Operating_Mode", 4, all_motors_except_gripper)
@@ -396,9 +384,7 @@ class ManipulatorRobot:
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms.
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")

def set_aloha_robot_preset(self):
def set_shadow_(arm):
@@ -428,15 +414,11 @@ class ManipulatorRobot:
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name
for name in self.follower_arms[name].motor_names
if name != "gripper"
name for name in self.follower_arms[name].motor_names if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Aloha motors
self.follower_arms[name].write(
"Operating_Mode", 4, all_motors_except_gripper
)
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)

# Use 'position control current based' for follower gripper to be limited by the limit of the current.
# It can grasp an object without forcing too much even tho,
@@ -484,9 +466,7 @@ class ManipulatorRobot:
before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = (
time.perf_counter() - before_lread_t
)
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t

# Send goal position to the follower
follower_goal_pos = {}
@@ -507,18 +487,14 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)

# Used when record_data=True
follower_goal_pos[name] = goal_pos

goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
time.perf_counter() - before_fwrite_t
)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t

# Early exit when recording data is not requested
if not record_data:
@@ -531,9 +507,7 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t

# Create state by concatenating follower current position
state = []
@@ -555,12 +529,8 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

# Populate output dictionaries
obs_dict, action_dict = {}, {}
@@ -584,9 +554,7 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t

# Create state by concatenating follower current position
state = []
@@ -601,12 +569,8 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

# Populate output dictionaries and format to pytorch
obs_dict = {}
@@ -652,9 +616,7 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)

# Save tensor to concat and return
action_sent.append(goal_pos)

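The max_relative_target guard above relies on ensure_safe_goal_position. A hypothetical sketch of the clamping such a helper performs (the actual lerobot implementation may differ):

import torch

def clamp_goal(goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float) -> torch.Tensor:
    # Limit how far a single command can move the arm away from where it currently is.
    safe_diff = torch.clamp(goal_pos - present_pos, -max_relative_target, max_relative_target)
    return present_pos + safe_diff
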
@@ -271,9 +271,7 @@ class MobileManipulator:
calibration = json.load(f)
else:
print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
@@ -303,9 +301,7 @@ class MobileManipulator:
bus.write("Torque_Enable", 0, motor_id)

# Then filter out wheels
arm_only_dict = {
k: v for k, v in bus.motors.items() if not k.startswith("wheel_")
}
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
if not arm_only_dict:
continue

@@ -377,9 +373,7 @@ class MobileManipulator:
if new_arm_state is not None and frames is not None:
self.last_frames = frames

remote_arm_state_tensor = torch.tensor(
new_arm_state, dtype=torch.float32
)
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
self.last_remote_arm_state = remote_arm_state_tensor

present_speed = new_speed
@@ -405,10 +399,7 @@ class MobileManipulator:
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
state_tensor = torch.zeros(3, dtype=torch.int32)
if present_speed:
decoded = {
key: MobileManipulator.raw_to_degps(value)
for key, value in present_speed.items()
}
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
if "1" in decoded:
state_tensor[0] = decoded["1"]
if "2" in decoded:
@@ -421,9 +412,7 @@ class MobileManipulator:
self, record_data: bool = False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"MobileManipulator is not connected. Run `connect()` first."
)
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")

speed_setting = self.speed_levels[self.speed_index]
xy_speed = speed_setting["xy"]  # e.g. 0.1, 0.25, or 0.4
@@ -495,9 +484,7 @@ class MobileManipulator:
body_state[2],
)  # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat(
(remote_arm_state_tensor, wheel_state_tensor), dim=0
)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)

obs_dict = {"observation.state": combined_state_tensor}

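A small worked example (made-up values) of the observation assembly at the end of the MobileManipulator hunk above: the remote arm joint state and the wheel body velocities are concatenated into one flat state vector.

import torch

remote_arm_state_tensor = torch.tensor([0.0, 12.5, -3.2, 45.0, 0.0, 10.0], dtype=torch.float32)
wheel_state_tensor = torch.tensor([120.0, 0.0, 15.0], dtype=torch.float32)  # x mm/s, y mm/s, theta deg/s
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
obs_dict = {"observation.state": combined_state_tensor}  # shape (9,)
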
@@ -52,9 +52,7 @@ class StretchRobot(StretchAPI):
def connect(self) -> None:
self.is_connected = self.startup()
if not self.is_connected:
print(
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
)
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError()

for name in self.cameras:
@@ -62,9 +60,7 @@ class StretchRobot(StretchAPI):
self.is_connected = self.is_connected and self.cameras[name].is_connected

if not self.is_connected:
print(
"Could not connect to the cameras, check that all cameras are plugged-in."
)
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()

self.run_calibration()
@@ -109,12 +105,8 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

# Populate output dictionaries
obs_dict, action_dict = {}, {}
@@ -158,12 +150,8 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

# Populate output dictionaries
obs_dict = {}

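The per-device timing idiom that recurs in the ManipulatorRobot and StretchRobot hunks is easy to reproduce in isolation (sketch, not from the commit): wall-clock deltas from time.perf_counter() are stored in a plain logs dict keyed per device.

import time

logs = {}
before_camread_t = time.perf_counter()
time.sleep(0.01)  # stand-in for a blocking camera read
logs["async_read_camera_main_dt_s"] = time.perf_counter() - before_camread_t
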
@@ -69,9 +69,7 @@ class HubMixin:
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name  # Defaults to `save_directory` name
return self.push_to_hub(
repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs
)
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return None

def _save_pretrained(self, save_directory: Path) -> None:
@@ -177,9 +175,7 @@ class HubMixin:
The url of the commit of your object in the given repository.
"""
api = HfApi(token=token)
repo_id = api.create_repo(
repo_id=repo_id, private=private, exist_ok=True
).repo_id
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

if commit_message is None:
if "Policy" in self.__class__.__name__:

@@ -17,9 +17,7 @@ import importlib
import logging


def is_package_available(
pkg_name: str, return_version: bool = False
) -> tuple[bool, str] | bool:
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.

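A condensed sketch of what is_package_available does (the transformers original it credits has more edge-case handling): probe the import machinery for the spec, then ask package metadata for the installed version.

import importlib.metadata
import importlib.util

def is_package_available(pkg_name: str, return_version: bool = False):
    # find_spec avoids importing the package (or a same-named local directory).
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            package_exists = False
    return (package_exists, package_version) if return_version else package_exists
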
@@ -20,16 +20,7 @@ from typing import TypeVar

import imageio

JsonLike = (
str
| int
| float
| bool
| None
| list["JsonLike"]
| dict[str, "JsonLike"]
| tuple["JsonLike", ...]
)
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
T = TypeVar("T", bound=JsonLike)


@@ -85,9 +76,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:

# Check length
if len(target) != len(source):
raise ValueError(
f"List length mismatch: expected {len(target)}, got {len(source)}"
)
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")

# Recursively update each element.
for i in range(len(target)):
@@ -99,14 +88,10 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# which we'll convert back to a tuple.
elif isinstance(target, tuple):
if not isinstance(source, list):
raise TypeError(
f"Type mismatch: expected list (for tuple), got {type(source)}"
)
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")

if len(target) != len(source):
raise ValueError(
f"Tuple length mismatch: expected {len(target)}, got {len(source)}"
)
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")

# Convert each element, forming a new tuple.
converted_items = []
@@ -120,9 +105,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
else:
# Check the exact type. If these must match 1:1, do:
if type(target) is not type(source):
raise TypeError(
f"Type mismatch: expected {type(target)}, got {type(source)}"
)
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
return source

# Perform the in-place/recursive deserialization

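A usage sketch for deserialize_json_into_object (the file content here is hypothetical): the JSON on disk must mirror the template object's structure, and JSON lists that correspond to tuples in the template are converted back to tuples, as the hunk above enforces.

template = {"name": "", "shape": (0, 0), "tags": []}
# Given a file containing {"name": "img", "shape": [3, 224], "tags": ["rgb"]},
# deserialize_json_into_object(fpath, template) would return
# {"name": "img", "shape": (3, 224), "tags": ["rgb"]} -- note the restored tuple.
# A length or type mismatch against the template raises ValueError/TypeError instead.
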
@@ -107,17 +107,13 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames

def __getattr__(
self, name: str
) -> int | dict[str, AverageMeter] | AverageMeter | Any:
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
return self.__dict__[name]
elif name in self.metrics:
return self.metrics[name]
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

def __setattr__(self, name: str, value: Any) -> None:
if name in self.__dict__:
@@ -125,9 +121,7 @@ class MetricsTracker:
elif name in self.metrics:
self.metrics[name].update(value)
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

def step(self) -> None:
"""

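The __getattr__/__setattr__ pair above forwards unknown attribute reads and writes to the metrics dict, so tracker.loss = 0.1 updates a meter rather than creating a plain attribute. A standalone sketch of the idiom (simplified, no AverageMeter):

class Forwarder:
    def __init__(self, metrics: dict):
        self.__dict__["metrics"] = metrics  # write via __dict__ to bypass __setattr__

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self.metrics[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        if name in self.metrics:
            self.metrics[name] = value
        else:
            raise AttributeError(name)
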
@@ -123,9 +123,7 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = {
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
}
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}

deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict)

@@ -48,9 +48,7 @@ def auto_select_torch_device() -> torch.device:
logging.info("Metal backend detected, using cuda.")
return torch.device("mps")
else:
logging.warning(
"No accelerated backend detected. Using default cpu, this will be slow."
)
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
return torch.device("cpu")


@@ -98,9 +96,7 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu":
return True
else:
raise ValueError(
f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu."
)
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")


def is_amp_available(device: str):
@@ -158,10 +154,7 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
except ValueError:  # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join(
[".."] * (len(path2.parts) - len(common_parts))
+ list(path1.parts[len(common_parts) :])
)
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)


@@ -172,26 +165,10 @@ def print_cuda_memory_usage():
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print(
"Current GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.memory_allocated(0) / 1024**2
)
)
print(
"Maximum GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.max_memory_allocated(0) / 1024**2
)
)
print(
"Current GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.memory_reserved(0) / 1024**2
)
)
print(
"Maximum GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.max_memory_reserved(0) / 1024**2
)
)
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))


def capture_timestamp_utc():
@@ -223,9 +200,7 @@ def say(text, blocking=False):
if blocking:
subprocess.run(cmd, check=True)
else:
subprocess.Popen(
cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
)
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)


def log_say(text, play_sounds, blocking=False):

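Both auto_select_torch_device and is_torch_device_available above hinge on two backend probes that can be checked directly; this sketch (not from the commit) uses only torch calls that exist as written:

import torch

print(torch.cuda.is_available())          # NVIDIA CUDA backend
print(torch.backends.mps.is_available())  # Apple Metal (MPS) backend
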
@@ -26,9 +26,7 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig


def cfg_to_group(
cfg: TrainPipelineConfig, return_list: bool = False
) -> list[str] | str:
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
@@ -95,9 +93,7 @@ class WandBLogger:
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb

def log_policy(self, checkpoint_dir: Path):
@@ -109,9 +105,7 @@ class WandBLogger:
artifact_name = f"{self._group}-{step_id}"
artifact_name = get_safe_wandb_artifact_name(artifact_name)
artifact = self._wandb.Artifact(artifact_name, type="model")
artifact.add_file(
checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE
)
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)

def log_dict(self, d: dict, step: int, mode: str = "train"):