[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
76df8a31b3
commit
38f5fa4523
@@ -127,7 +127,9 @@ class AsyncImageWriter:
|
||||
self._stopped = False
|
||||
|
||||
if num_threads <= 0 and num_processes <= 0:
|
||||
raise ValueError("Number of threads and processes must be greater than zero.")
|
||||
raise ValueError(
|
||||
"Number of threads and processes must be greater than zero."
|
||||
)
|
||||
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
@@ -141,12 +143,16 @@ class AsyncImageWriter:
|
||||
# Use multiprocessing
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||
p = multiprocessing.Process(
|
||||
target=worker_process, args=(self.queue, self.num_threads)
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
def save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
|
||||
@@ -139,7 +139,9 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
fpath = self.video_path.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
|
||||
)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
@@ -183,7 +185,11 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
return [
|
||||
key
|
||||
for key, ft in self.features.items()
|
||||
if ft["dtype"] in ["video", "image"]
|
||||
]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
@@ -285,7 +291,9 @@ class LeRobotDatasetMetadata:
|
||||
"""
|
||||
for key in self.video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
video_path = self.root / self.get_video_file_path(
|
||||
ep_index=0, vid_key=key
|
||||
)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -619,7 +627,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
path = str(self.root / "data")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
files = [
|
||||
str(self.root / self.meta.get_data_file_path(ep_idx))
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
@@ -643,12 +654,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames in selected episodes."""
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
|
||||
return (
|
||||
len(self.hf_dataset)
|
||||
if self.hf_dataset is not None
|
||||
else self.meta.total_frames
|
||||
)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
|
||||
return (
|
||||
len(self.episodes)
|
||||
if self.episodes is not None
|
||||
else self.meta.total_episodes
|
||||
)
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
@@ -662,16 +681,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
def _get_query_indices(
|
||||
self, idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
query_indices = {
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [
|
||||
max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
|
||||
for delta in delta_idx
|
||||
]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
[
|
||||
(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
|
||||
for delta in delta_idx
|
||||
]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -771,13 +798,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
def _get_image_file_path(
|
||||
self, episode_index: int, image_key: str, frame_index: int
|
||||
) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
def _save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
|
||||
) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
@@ -803,7 +834,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
timestamp = (
|
||||
frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
)
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
|
||||
@@ -821,7 +854,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
episode_index=self.episode_buffer["episode_index"],
|
||||
image_key=key,
|
||||
frame_index=frame_index,
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -1132,7 +1167,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
||||
features.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in dataset.hf_features.items()
|
||||
if k not in self.disabled_features
|
||||
}
|
||||
)
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -1193,7 +1234,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
raise AssertionError(
|
||||
"We expect the loop to break out as long as the index is within bounds."
|
||||
)
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
|
||||
@@ -131,7 +131,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
def _make_data_spec(
|
||||
self, data_spec: dict[str, Any], buffer_capacity: int
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
@@ -154,14 +156,32 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {
|
||||
"dtype": np.dtype("?"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {
|
||||
"dtype": np.dtype("float64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
complete_data_spec[k] = {
|
||||
"dtype": v["dtype"],
|
||||
"shape": (buffer_capacity, *v["shape"]),
|
||||
}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
@@ -188,7 +208,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_frames > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
next_index - 1
|
||||
]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
@@ -223,7 +245,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(
|
||||
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
np.unique(
|
||||
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -261,7 +287,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
||||
)
|
||||
)[0]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
|
||||
episode_data_indices
|
||||
]
|
||||
|
||||
for data_key in self.delta_timestamps:
|
||||
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
||||
@@ -278,7 +306,8 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Check violated query timestamps are all outside the episode range.
|
||||
assert (
|
||||
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
||||
(query_ts[is_pad] < episode_timestamps[0])
|
||||
| (episode_timestamps[-1] < query_ts[is_pad])
|
||||
).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
||||
") inside the episode range."
|
||||
@@ -293,7 +322,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
|
||||
"""Returns all data for a given data key as a Tensor."""
|
||||
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
return torch.from_numpy(
|
||||
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
|
||||
)
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
@@ -324,13 +355,19 @@ def compute_sampler_weights(
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if len(offline_dataset) == 0 and (
|
||||
online_dataset is None or len(online_dataset) == 0
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of `offline_dataset` or `online_dataset` should be contain data."
|
||||
)
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
offline_sampling_ratio = (
|
||||
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
)
|
||||
|
||||
weights = []
|
||||
|
||||
|
||||
@@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts):
|
||||
return data_dict
|
||||
|
||||
|
||||
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
|
||||
def save_images_concurrently(
|
||||
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
|
||||
):
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -55,7 +57,10 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
[
|
||||
executor.submit(save_image, imgs_array[i], i, out_dir)
|
||||
for i in range(num_images)
|
||||
]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
@@ -64,7 +69,8 @@ def get_default_encoding() -> dict:
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
if v.default is not inspect.Parameter.empty
|
||||
and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
}
|
||||
|
||||
|
||||
@@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None:
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
def calculate_episode_data_index(
|
||||
hf_dataset: datasets.Dataset,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
|
||||
@@ -43,7 +43,10 @@ class EpisodeAwareSampler:
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
indices.extend(
|
||||
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||
range(
|
||||
start_index.item() + drop_n_first_frames,
|
||||
end_index.item() - drop_n_last_frames,
|
||||
)
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
|
||||
@@ -58,7 +58,9 @@ class RandomSubsetApply(Transform):
|
||||
elif not isinstance(n_subset, int):
|
||||
raise TypeError("n_subset should be an int or None")
|
||||
elif not (1 <= n_subset <= len(transforms)):
|
||||
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
|
||||
raise ValueError(
|
||||
f"n_subset should be in the interval [1, {len(transforms)}]"
|
||||
)
|
||||
|
||||
self.transforms = transforms
|
||||
total = sum(p)
|
||||
@@ -119,16 +121,22 @@ class SharpnessJitter(Transform):
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
raise ValueError(
|
||||
"If sharpness is a single number, it must be non negative."
|
||||
)
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
sharpness[0] = max(sharpness[0], 0.0)
|
||||
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
||||
sharpness = [float(v) for v in sharpness]
|
||||
else:
|
||||
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
|
||||
raise TypeError(
|
||||
f"{sharpness=} should be a single number or a sequence with length 2."
|
||||
)
|
||||
|
||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
|
||||
raise ValueError(
|
||||
f"sharpnesss values should be between (0., inf), but got {sharpness}."
|
||||
)
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
|
||||
@@ -52,9 +52,15 @@ STATS_PATH = "meta/stats.json"
|
||||
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
DEFAULT_VIDEO_PATH = (
|
||||
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
)
|
||||
DEFAULT_PARQUET_PATH = (
|
||||
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
)
|
||||
DEFAULT_IMAGE_PATH = (
|
||||
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
)
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
@@ -540,7 +546,10 @@ def check_timestamps_sync(
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
@@ -548,10 +557,14 @@ def check_delta_timestamps(
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
within_tolerance = [
|
||||
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
|
||||
]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
ts
|
||||
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
|
||||
if not is_within
|
||||
]
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
@@ -569,7 +582,9 @@ def check_delta_timestamps(
|
||||
return True
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
def get_delta_indices(
|
||||
delta_timestamps: dict[str, list[float]], fps: int
|
||||
) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
@@ -634,7 +649,9 @@ def create_lerobot_dataset_card(
|
||||
],
|
||||
)
|
||||
|
||||
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
|
||||
card_template = (
|
||||
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
|
||||
).read_text()
|
||||
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
|
||||
@@ -118,7 +118,10 @@ DATASETS = {
|
||||
"single_task": "Place the battery into the slot of the remote controller.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
|
||||
"aloha_static_candy": {
|
||||
"single_task": "Pick up the candy and unwrap it.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_coffee": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
@@ -167,13 +170,22 @@ DATASETS = {
|
||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_static_ziploc_slide": {
|
||||
"single_task": "Slide open the ziploc bag.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_scripted": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_scripted_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_human": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
@@ -194,10 +206,19 @@ DATASETS = {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"pusht": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"pusht_image": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {
|
||||
"single_task": "Put the object into the bin.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"unitreeh1_two_robot_greeting": {
|
||||
"single_task": "Greet the other robot with a high five.",
|
||||
**UNITREEH_INFO,
|
||||
@@ -207,13 +228,31 @@ DATASETS = {
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"umi_cup_in_the_wild": {
|
||||
"single_task": "Put the cup on the plate.",
|
||||
"license": "apache-2.0",
|
||||
|
||||
@@ -218,7 +218,9 @@ def get_features_from_hf_dataset(
|
||||
dtype = ft.feature.dtype
|
||||
shape = (ft.length,)
|
||||
motor_names = (
|
||||
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
|
||||
robot_config["names"][key]
|
||||
if robot_config
|
||||
else [f"motor_{i}" for i in range(ft.length)]
|
||||
)
|
||||
assert len(motor_names) == shape[0]
|
||||
names = {"motors": motor_names}
|
||||
@@ -242,11 +244,15 @@ def get_features_from_hf_dataset(
|
||||
return features
|
||||
|
||||
|
||||
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
|
||||
def add_task_index_by_episodes(
|
||||
dataset: Dataset, tasks_by_episodes: dict
|
||||
) -> tuple[Dataset, list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
tasks = list(set(tasks_by_episodes.values()))
|
||||
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
|
||||
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
episodes_to_task_index = {
|
||||
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
|
||||
|
||||
features = dataset.features
|
||||
@@ -263,10 +269,19 @@ def add_task_index_from_tasks_col(
|
||||
# HACK: This is to clean some of the instructions in our version of Open X datasets
|
||||
prefix_to_clean = "tf.Tensor(b'"
|
||||
suffix_to_clean = "', shape=(), dtype=string)"
|
||||
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
|
||||
df[tasks_col] = (
|
||||
df[tasks_col]
|
||||
.str.removeprefix(prefix_to_clean)
|
||||
.str.removesuffix(suffix_to_clean)
|
||||
)
|
||||
|
||||
# Create task_index col
|
||||
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
|
||||
tasks_by_episode = (
|
||||
df.groupby("episode_index")[tasks_col]
|
||||
.unique()
|
||||
.apply(lambda x: x.tolist())
|
||||
.to_dict()
|
||||
)
|
||||
tasks = df[tasks_col].unique().tolist()
|
||||
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
|
||||
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
|
||||
@@ -291,7 +306,9 @@ def split_parquet_by_episodes(
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
|
||||
episode_chunk=ep_chunk
|
||||
)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
@@ -323,7 +340,9 @@ def move_videos(
|
||||
videos_moved = False
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
|
||||
if len(video_files) == 0:
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
|
||||
video_files = [
|
||||
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
|
||||
]
|
||||
videos_moved = True # Videos have already been moved
|
||||
|
||||
assert len(video_files) == total_episodes * len(video_keys)
|
||||
@@ -354,7 +373,9 @@ def move_videos(
|
||||
target_path = DEFAULT_VIDEO_PATH.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
|
||||
video_file = V1_VIDEO_FILE.format(
|
||||
video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
if len(video_dirs) == 1:
|
||||
video_path = video_dirs[0] / video_file
|
||||
else:
|
||||
@@ -371,7 +392,9 @@ def move_videos(
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
|
||||
def fix_lfs_video_files_tracking(
|
||||
work_dir: Path, lfs_untracked_videos: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
|
||||
there's no other option than to download the actual files and reupload them with lfs tracking.
|
||||
@@ -379,7 +402,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
|
||||
for i in range(0, len(lfs_untracked_videos), 100):
|
||||
files = lfs_untracked_videos[i : i + 100]
|
||||
try:
|
||||
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
|
||||
subprocess.run(
|
||||
["git", "rm", "--cached", *files],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("git rm --cached ERROR:")
|
||||
print(e.stderr)
|
||||
@@ -390,10 +418,14 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
|
||||
def fix_gitattributes(
|
||||
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
|
||||
) -> None:
|
||||
shutil.copyfile(clean_gittatributes, current_gittatributes)
|
||||
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
|
||||
)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
@@ -402,7 +434,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
||||
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
||||
subprocess.run(
|
||||
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--branch",
|
||||
branch,
|
||||
"--single-branch",
|
||||
"--depth",
|
||||
"1",
|
||||
repo_url,
|
||||
str(work_dir),
|
||||
],
|
||||
check=True,
|
||||
env=env,
|
||||
)
|
||||
@@ -410,13 +452,19 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
|
||||
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
||||
lfs_tracked_files = subprocess.run(
|
||||
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
|
||||
["git", "lfs", "ls-files", "-n"],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
||||
def get_videos_info(
|
||||
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
|
||||
) -> dict:
|
||||
# Assumes first episode
|
||||
video_files = [
|
||||
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
||||
@@ -424,7 +472,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
|
||||
]
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
revision=branch,
|
||||
allow_patterns=video_files,
|
||||
)
|
||||
videos_info_dict = {}
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
@@ -451,7 +503,11 @@ def convert_dataset(
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=v1,
|
||||
local_dir=v1x_dir,
|
||||
ignore_patterns="videos*/",
|
||||
)
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
@@ -483,19 +539,31 @@ def convert_dataset(
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
elif tasks_path:
|
||||
tasks_by_episodes = load_json(tasks_path)
|
||||
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
|
||||
tasks_by_episodes = {
|
||||
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
elif tasks_col:
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
|
||||
dataset, tasks_col
|
||||
)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
assert set(tasks) == {
|
||||
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
|
||||
}
|
||||
tasks = [
|
||||
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
|
||||
]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
@@ -509,14 +577,25 @@ def convert_dataset(
|
||||
dataset = dataset.remove_columns(video_keys)
|
||||
clean_gitattr = Path(
|
||||
hub_api.hf_hub_download(
|
||||
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
||||
repo_id=GITATTRIBUTES_REF,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
filename=".gitattributes",
|
||||
)
|
||||
).absolute()
|
||||
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
||||
move_videos(
|
||||
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
|
||||
repo_id,
|
||||
video_keys,
|
||||
total_episodes,
|
||||
total_chunks,
|
||||
Path(tmp_video_dir),
|
||||
clean_gitattr,
|
||||
branch,
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||
videos_info = get_videos_info(
|
||||
repo_id, v1x_dir, video_keys=video_keys, branch=branch
|
||||
)
|
||||
for key in video_keys:
|
||||
features[key]["shape"] = (
|
||||
videos_info[key].pop("video.height"),
|
||||
@@ -524,15 +603,22 @@ def convert_dataset(
|
||||
videos_info[key].pop("video.channels"),
|
||||
)
|
||||
features[key]["video_info"] = videos_info[key]
|
||||
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
||||
assert math.isclose(
|
||||
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
|
||||
)
|
||||
if "encoding" in metadata_v1:
|
||||
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
||||
assert (
|
||||
videos_info[key]["video.pix_fmt"]
|
||||
== metadata_v1["encoding"]["pix_fmt"]
|
||||
)
|
||||
else:
|
||||
assert metadata_v1.get("video", 0) == 0
|
||||
videos_info = None
|
||||
|
||||
# Split data into 1 parquet file by episode
|
||||
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
||||
episode_lengths = split_parquet_by_episodes(
|
||||
dataset, total_episodes, total_chunks, v20_dir
|
||||
)
|
||||
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config.type
|
||||
@@ -543,7 +629,11 @@ def convert_dataset(
|
||||
|
||||
# Episodes
|
||||
episodes = [
|
||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
{
|
||||
"episode_index": ep_idx,
|
||||
"tasks": tasks_by_episodes[ep_idx],
|
||||
"length": episode_lengths[ep_idx],
|
||||
}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
@@ -566,16 +656,27 @@ def convert_dataset(
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="meta_data",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
|
||||
)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
|
||||
@@ -344,7 +344,9 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
@@ -358,7 +360,9 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"])
|
||||
if audio_stream_info.get("bit_rate")
|
||||
else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
@@ -380,7 +384,9 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
|
||||
@@ -70,7 +70,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
return env
|
||||
|
||||
|
||||
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
def make_maniskill_env(
|
||||
cfg: DictConfig, n_envs: int | None = None
|
||||
) -> gym.vector.VectorEnv | None:
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
@@ -87,7 +89,9 @@ def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector
|
||||
# state should have the size of 25
|
||||
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||
# env = PixelWrapper(cfg, env, n_envs)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env._max_episode_steps = env.max_episode_steps = (
|
||||
50 # gym_utils.find_max_episode_steps_value(env)
|
||||
)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
|
||||
return env
|
||||
@@ -114,7 +118,11 @@ class PixelWrapper(gym.Wrapper):
|
||||
def _get_obs(self, obs):
|
||||
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||
self._frames.append(frame)
|
||||
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
|
||||
return {
|
||||
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
|
||||
self.env.device
|
||||
)
|
||||
}
|
||||
|
||||
def reset(self, seed):
|
||||
obs, info = self.env.reset() # (seed=seed)
|
||||
@@ -148,7 +156,9 @@ class ConvertToLeRobotEnv(gym.Wrapper):
|
||||
|
||||
images = torch.concat(images, axis=-1)
|
||||
# flatten the rest of the data which should just be state data
|
||||
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
|
||||
observation = common.flatten_state_dict(
|
||||
observation, use_torch=True, device=self.base_env.device
|
||||
)
|
||||
ret = dict()
|
||||
ret["state"] = observation
|
||||
ret["pixels"] = images
|
||||
|
||||
@@ -84,7 +84,9 @@ class Logger:
|
||||
pretrained_model_dir_name = "pretrained_model"
|
||||
training_state_file_name = "training_state.pth"
|
||||
|
||||
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
|
||||
def __init__(
|
||||
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
log_dir: The directory to save all logs and training outputs to.
|
||||
@@ -104,7 +106,9 @@ class Logger:
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
logging.info(
|
||||
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
|
||||
)
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
@@ -130,7 +134,9 @@ class Logger:
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
logging.info(
|
||||
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
|
||||
)
|
||||
self._wandb = wandb
|
||||
|
||||
@classmethod
|
||||
@@ -151,7 +157,9 @@ class Logger:
|
||||
"""
|
||||
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
|
||||
|
||||
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
|
||||
def save_model(
|
||||
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
|
||||
):
|
||||
"""Save the weights of the Policy model using PyTorchModelHubMixin.
|
||||
|
||||
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
|
||||
@@ -221,22 +229,30 @@ class Logger:
|
||||
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
|
||||
)
|
||||
self.save_model(
|
||||
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
|
||||
checkpoint_dir / self.pretrained_model_dir_name,
|
||||
policy,
|
||||
wandb_artifact_name=wandb_artifact_name,
|
||||
)
|
||||
self.save_training_state(
|
||||
checkpoint_dir, train_step, optimizer, scheduler, interaction_step
|
||||
)
|
||||
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
|
||||
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
|
||||
|
||||
def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
|
||||
def load_last_training_state(
|
||||
self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
|
||||
) -> int:
|
||||
"""
|
||||
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
|
||||
random state, and return the global training step.
|
||||
"""
|
||||
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
|
||||
training_state = torch.load(
|
||||
self.last_checkpoint_dir / self.training_state_file_name
|
||||
)
|
||||
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
|
||||
if type(training_state["optimizer"]) is dict:
|
||||
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
|
||||
"Optimizer dictionaries do not have the same keys during resume!"
|
||||
)
|
||||
assert set(training_state["optimizer"].keys()) == set(
|
||||
optimizer.keys()
|
||||
), "Optimizer dictionaries do not have the same keys during resume!"
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizer[k].load_state_dict(v)
|
||||
else:
|
||||
@@ -248,10 +264,18 @@ class Logger:
|
||||
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
|
||||
)
|
||||
# Small hack to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
set_global_random_state(
|
||||
{k: training_state[k] for k in get_global_random_state()}
|
||||
)
|
||||
return training_state["step"]
|
||||
|
||||
def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
|
||||
def log_dict(
|
||||
self,
|
||||
d,
|
||||
step: int | None = None,
|
||||
mode="train",
|
||||
custom_step_key: str | None = None,
|
||||
):
|
||||
"""Log a dictionary of metrics to WandB."""
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
@@ -280,12 +304,20 @@ class Logger:
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
if (
|
||||
self._wandb_custom_step_key is not None
|
||||
and k in self._wandb_custom_step_key
|
||||
):
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step})
|
||||
self._wandb.log(
|
||||
{
|
||||
f"{mode}/{k}": v,
|
||||
f"{mode}/{custom_step_key}": value_custom_step,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
@@ -74,7 +74,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
self.model = ACT(config)
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(
|
||||
config.temporal_ensemble_coeff, config.chunk_size
|
||||
)
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -153,7 +155,8 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||
* ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
@@ -163,7 +166,12 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
(
|
||||
-0.5
|
||||
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
|
||||
)
|
||||
.sum(-1)
|
||||
.mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||
@@ -217,7 +225,9 @@ class ACTTemporalEnsembler:
|
||||
```
|
||||
"""
|
||||
self.chunk_size = chunk_size
|
||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||
self.ensemble_weights = torch.exp(
|
||||
-temporal_ensemble_coeff * torch.arange(chunk_size)
|
||||
)
|
||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||
self.reset()
|
||||
|
||||
@@ -233,7 +243,9 @@ class ACTTemporalEnsembler:
|
||||
time steps, and pop/return the next batch of actions in the sequence.
|
||||
"""
|
||||
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
|
||||
device=actions.device
|
||||
)
|
||||
if self.ensembled_actions is None:
|
||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
@@ -241,19 +253,34 @@ class ACTTemporalEnsembler:
|
||||
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||
# operations later.
|
||||
self.ensembled_actions_count = torch.ones(
|
||||
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||
(self.chunk_size, 1),
|
||||
dtype=torch.long,
|
||||
device=self.ensembled_actions.device,
|
||||
)
|
||||
else:
|
||||
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the online update for those entries.
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
|
||||
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
|
||||
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[
|
||||
self.ensembled_actions_count - 1
|
||||
]
|
||||
self.ensembled_actions += (
|
||||
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
)
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[
|
||||
self.ensembled_actions_count
|
||||
]
|
||||
self.ensembled_actions_count = torch.clamp(
|
||||
self.ensembled_actions_count + 1, max=self.chunk_size
|
||||
)
|
||||
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||
self.ensembled_actions = torch.cat(
|
||||
[self.ensembled_actions, actions[:, -1:]], dim=1
|
||||
)
|
||||
self.ensembled_actions_count = torch.cat(
|
||||
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||
[
|
||||
self.ensembled_actions_count,
|
||||
torch.ones_like(self.ensembled_actions_count[-1:]),
|
||||
]
|
||||
)
|
||||
# "Consume" the first action.
|
||||
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||
@@ -319,7 +346,9 @@ class ACT(nn.Module):
|
||||
config.dim_model,
|
||||
)
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(
|
||||
config.dim_model, config.latent_dim * 2
|
||||
)
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
@@ -327,20 +356,28 @@ class ACT(nn.Module):
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
||||
create_sinusoidal_pos_embedding(
|
||||
num_input_token_encoder, config.dim_model
|
||||
).unsqueeze(0),
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
if self.config.image_features:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
replace_stride_with_dilation=[
|
||||
False,
|
||||
False,
|
||||
config.replace_final_stride_with_dilation,
|
||||
],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
# feature map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
self.backbone = IntermediateLayerGetter(
|
||||
backbone_model, return_layers={"layer4": "feature_map"}
|
||||
)
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = ACTEncoder(config)
|
||||
@@ -386,7 +423,9 @@ class ACT(nn.Module):
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
def forward(
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
@@ -424,7 +463,9 @@ class ACT(nn.Module):
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(
|
||||
batch["action"]
|
||||
) # (B, S, D)
|
||||
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
@@ -465,20 +506,24 @@ class ACT(nn.Module):
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
mu = log_sigma_x2 = None
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
)
|
||||
latent_sample = torch.zeros(
|
||||
[batch_size, self.config.latent_dim], dtype=torch.float32
|
||||
).to(batch["observation.state"].device)
|
||||
|
||||
# Prepare transformer encoder inputs.
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
encoder_in_pos_embed = list(
|
||||
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
|
||||
)
|
||||
# Robot state token.
|
||||
if self.config.robot_state_feature:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
# Environment state token.
|
||||
if self.config.env_state_feature:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
self.encoder_env_state_input_proj(
|
||||
batch["observation.environment_state"]
|
||||
)
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
@@ -535,12 +580,21 @@ class ACTEncoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
|
||||
super().__init__()
|
||||
self.is_vae_encoder = is_vae_encoder
|
||||
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
|
||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
|
||||
num_layers = (
|
||||
config.n_vae_encoder_layers
|
||||
if self.is_vae_encoder
|
||||
else config.n_encoder_layers
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[ACTEncoderLayer(config) for _ in range(num_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
self,
|
||||
x: Tensor,
|
||||
pos_embed: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
||||
@@ -551,7 +605,9 @@ class ACTEncoder(nn.Module):
|
||||
class ACTEncoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
@@ -566,7 +622,9 @@ class ACTEncoderLayer(nn.Module):
|
||||
self.activation = get_activation_fn(config.feedforward_activation)
|
||||
self.pre_norm = config.pre_norm
|
||||
|
||||
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
|
||||
def forward(
|
||||
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
) -> Tensor:
|
||||
skip = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
@@ -591,7 +649,9 @@ class ACTDecoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
def forward(
|
||||
@@ -603,7 +663,10 @@ class ACTDecoder(nn.Module):
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(
|
||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||
x,
|
||||
encoder_out,
|
||||
decoder_pos_embed=decoder_pos_embed,
|
||||
encoder_pos_embed=encoder_pos_embed,
|
||||
)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
@@ -613,8 +676,12 @@ class ACTDecoder(nn.Module):
|
||||
class ACTDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
@@ -655,7 +722,9 @@ class ACTDecoderLayer(nn.Module):
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = self.self_attn(q, k, value=x)[
|
||||
0
|
||||
] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
@@ -692,9 +761,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
|
||||
"""
|
||||
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
|
||||
return [
|
||||
position / np.power(10000, 2 * (hid_j // 2) / dimension)
|
||||
for hid_j in range(dimension)
|
||||
]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
|
||||
)
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
return torch.from_numpy(sinusoid_table).float()
|
||||
@@ -739,7 +813,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
||||
|
||||
inverse_frequency = self._temperature ** (
|
||||
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
||||
2
|
||||
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
|
||||
/ self.dimension
|
||||
)
|
||||
|
||||
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
@@ -747,9 +823,15 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
|
||||
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
||||
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
|
||||
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
pos_embed_x = torch.stack(
|
||||
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
|
||||
).flatten(3)
|
||||
pos_embed_y = torch.stack(
|
||||
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
|
||||
).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
|
||||
0, 3, 1, 2
|
||||
) # (1, C, H, W)
|
||||
|
||||
return pos_embed
|
||||
|
||||
|
||||
@@ -132,7 +132,11 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
batch = {
|
||||
k: torch.stack(list(self._queues[k]), dim=1)
|
||||
for k in batch
|
||||
if k in self._queues
|
||||
}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
@@ -189,7 +193,9 @@ class DiffusionModel(nn.Module):
|
||||
if self.config.env_state_feature:
|
||||
global_cond_dim += self.config.env_state_feature.shape[0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
config, global_cond_dim=global_cond_dim * config.n_obs_steps
|
||||
)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
@@ -209,7 +215,10 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||
self,
|
||||
batch_size: int,
|
||||
global_cond: Tensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
@@ -232,7 +241,9 @@ class DiffusionModel(nn.Module):
|
||||
global_cond=global_cond,
|
||||
)
|
||||
# Compute previous image: x_t -> x_t-1
|
||||
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||
sample = self.noise_scheduler.step(
|
||||
model_output, t, sample, generator=generator
|
||||
).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
@@ -244,27 +255,39 @@ class DiffusionModel(nn.Module):
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
images_per_camera = einops.rearrange(
|
||||
batch["observation.images"], "b s n ... -> n (b s) ..."
|
||||
)
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||
for encoder, images in zip(
|
||||
self.rgb_encoder, images_per_camera, strict=True
|
||||
)
|
||||
]
|
||||
)
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
img_features_list,
|
||||
"(n b s) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
einops.rearrange(
|
||||
batch["observation.images"], "b s n ... -> (b s n) ..."
|
||||
)
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
img_features,
|
||||
"(b s n) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
@@ -350,7 +373,9 @@ class DiffusionModel(nn.Module):
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
raise ValueError(
|
||||
f"Unsupported prediction type {self.config.prediction_type}"
|
||||
)
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
|
||||
@@ -410,7 +435,9 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x, pos_y = np.meshgrid(
|
||||
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
|
||||
)
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
@@ -452,7 +479,9 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(
|
||||
config.crop_shape
|
||||
)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -473,7 +502,9 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
@@ -515,7 +546,9 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
@@ -528,7 +561,11 @@ def _replace_submodules(
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
replace_list = [
|
||||
k.split(".")
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
@@ -543,7 +580,9 @@ def _replace_submodules(
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
assert not any(
|
||||
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
|
||||
)
|
||||
return root_module
|
||||
|
||||
|
||||
@@ -571,7 +610,9 @@ class DiffusionConv1dBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.Conv1d(
|
||||
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
|
||||
),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
@@ -594,9 +635,13 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
|
||||
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
|
||||
nn.Linear(
|
||||
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
|
||||
),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
|
||||
nn.Linear(
|
||||
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
|
||||
),
|
||||
)
|
||||
|
||||
# The FiLM conditioning dimension.
|
||||
@@ -621,10 +666,16 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_in, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_out, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
|
||||
if not is_last
|
||||
else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -633,10 +684,14 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -649,10 +704,16 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
nn.ModuleList(
|
||||
[
|
||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_in * 2, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_out, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
|
||||
if not is_last
|
||||
else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -726,17 +787,23 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||
self.use_film_scale_modulation = use_film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
self.conv1 = DiffusionConv1dBlock(
|
||||
in_channels, out_channels, kernel_size, n_groups=n_groups
|
||||
)
|
||||
|
||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
self.conv2 = DiffusionConv1dBlock(
|
||||
out_channels, out_channels, kernel_size, n_groups=n_groups
|
||||
)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
nn.Conv1d(in_channels, out_channels, 1)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||
|
||||
@@ -7,7 +7,9 @@ from torch import Tensor, nn
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -15,7 +17,10 @@ class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Optional[Tensor] = None,
|
||||
hidden_states: Optional[Tensor] = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
@@ -43,12 +48,14 @@ class Classifier(
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(
|
||||
self.config.model_name, trust_remote_code=True
|
||||
)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
@@ -74,7 +81,9 @@ class Classifier(
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[
|
||||
-1
|
||||
] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
@@ -94,14 +103,19 @@ class Classifier(
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
raise ValueError(
|
||||
"Unsupported transformer architecture since hidden_size is not found"
|
||||
)
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
self.classifier_head = self.classifier_head.to(self.config.device)
|
||||
|
||||
@@ -127,7 +141,10 @@ class Classifier(
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(processed)
|
||||
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
|
||||
if (
|
||||
hasattr(outputs, "pooler_output")
|
||||
and outputs.pooler_output is not None
|
||||
):
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
@@ -143,7 +160,9 @@ class Classifier(
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
return ClassifierOutput(
|
||||
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
|
||||
)
|
||||
|
||||
def predict_reward(self, x, threshold=0.6):
|
||||
if self.config.num_classes == 2:
|
||||
|
||||
@@ -59,7 +59,9 @@ class SACPolicy(
|
||||
config.input_normalization_params
|
||||
)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, input_normalization_params
|
||||
config.input_shapes,
|
||||
config.input_normalization_modes,
|
||||
input_normalization_params,
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
@@ -90,7 +92,8 @@ class SACPolicy(
|
||||
ensemble=Ensemble(
|
||||
[
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
input_dim=encoder_critic.output_dim
|
||||
+ config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
@@ -104,7 +107,8 @@ class SACPolicy(
|
||||
ensemble=Ensemble(
|
||||
[
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
input_dim=encoder_critic.output_dim
|
||||
+ config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
@@ -120,13 +124,17 @@ class SACPolicy(
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
network=MLP(
|
||||
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
|
||||
),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
config.target_entropy = (
|
||||
-np.prod(config.output_shapes["action"][0]) / 2
|
||||
) # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
@@ -153,7 +161,11 @@ class SACPolicy(
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
@@ -173,21 +185,37 @@ class SACPolicy(
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=False,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
next_action_preds, next_log_probs, _ = self.actor(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
@@ -204,7 +232,12 @@ class SACPolicy(
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
@@ -219,20 +252,31 @@ class SACPolicy(
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
def compute_loss_temperature(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
temperature_loss = (
|
||||
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
|
||||
).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
def compute_loss_actor(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
@@ -259,7 +303,11 @@ class MLP(nn.Module):
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
@@ -270,7 +318,9 @@ class MLP(nn.Module):
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
@@ -381,7 +431,11 @@ class CriticEnsemble(nn.Module):
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
q_values = self.ensemble(inputs) # [num_critics, B, 1]
|
||||
@@ -445,7 +499,11 @@ class Policy(nn.Module):
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
@@ -454,11 +512,15 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
assert not torch.isnan(
|
||||
log_std
|
||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
log_std = self.log_std_min + 0.5 * (
|
||||
self.log_std_max - self.log_std_min
|
||||
) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
@@ -471,7 +533,9 @@ class Policy(nn.Module):
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
log_probs -= torch.log(
|
||||
(1 - actions.pow(2)) + 1e-6
|
||||
) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
@@ -518,12 +582,15 @@ class SACObservationEncoder(nn.Module):
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.all_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
|
||||
in_features=config.input_shapes["observation.state"][0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
@@ -544,7 +611,9 @@ class SACObservationEncoder(nn.Module):
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.aggregation_layer = nn.Linear(
|
||||
in_features=self.aggregation_size, out_features=config.latent_dim
|
||||
)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
@@ -557,13 +626,19 @@ class SACObservationEncoder(nn.Module):
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
if len(self.all_image_keys) > 0:
|
||||
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = torch.cat(
|
||||
[obs_dict[key] for key in self.all_image_keys], dim=0
|
||||
)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
embeddings_chunks = torch.chunk(
|
||||
images_batched, dim=0, chunks=len(self.all_image_keys)
|
||||
)
|
||||
feat.extend(embeddings_chunks)
|
||||
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
feat.append(
|
||||
self.env_state_enc_layers(obs_dict["observation.environment_state"])
|
||||
)
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
@@ -631,7 +706,9 @@ class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_layers, self.image_enc_out_shape = (
|
||||
self._load_pretrained_vision_encoder(config)
|
||||
)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -642,15 +719,21 @@ class PretrainedImageEncoder(nn.Module):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
|
||||
self.image_enc_layers = AutoModel.from_pretrained(
|
||||
config.vision_encoder_name, trust_remote_code=True
|
||||
)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
|
||||
-1
|
||||
] # Last channel dimension
|
||||
elif hasattr(self.image_enc_layers, "fc"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
|
||||
else:
|
||||
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
|
||||
raise ValueError(
|
||||
"Unsupported vision encoder architecture, make sure you are using a CNN"
|
||||
)
|
||||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
@@ -673,7 +756,7 @@ def orthogonal_init():
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
@@ -701,7 +784,9 @@ class Ensemble(nn.Module):
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
|
||||
return torch.vmap(self._call, (0, None), randomness="different")(
|
||||
self.params, *args, **kwargs
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Vectorized {len(self)}x " + self._repr
|
||||
@@ -710,7 +795,9 @@ class Ensemble(nn.Module):
|
||||
# TODO (azouitine): I think in our case this function is not usefull we should remove it
|
||||
# after some investigation
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
def flatten_forward_unflatten(
|
||||
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
|
||||
) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
@@ -736,7 +823,9 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
for key, value in inner_dict.items():
|
||||
converted_params[outer_key][key] = torch.tensor(value)
|
||||
if "image" in outer_key:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
|
||||
converted_params[outer_key][key] = converted_params[outer_key][
|
||||
key
|
||||
].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
|
||||
@@ -183,7 +183,9 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
|
||||
)
|
||||
if not self.use_mpc:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
raise ValueError(
|
||||
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
|
||||
)
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
|
||||
|
||||
@@ -100,7 +100,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
"action": deque(
|
||||
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
|
||||
),
|
||||
}
|
||||
if self.config.image_features:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
@@ -189,7 +191,11 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
z = einops.repeat(
|
||||
z,
|
||||
"b d -> n b d",
|
||||
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
|
||||
)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
@@ -211,35 +217,47 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
self.config.action_feature.shape[0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
gaussian_actions = torch.clamp(
|
||||
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
|
||||
)
|
||||
|
||||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
|
||||
elite_idxs = torch.topk(
|
||||
value, self.config.n_elites, dim=0
|
||||
).indices # (n_elites, batch)
|
||||
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
|
||||
# (horizon, n_elites, batch, action_dim)
|
||||
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
|
||||
elite_actions = actions.take_along_dim(
|
||||
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
|
||||
)
|
||||
|
||||
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score = torch.exp(
|
||||
self.config.elite_weighting_temperature * (elite_value - max_value)
|
||||
)
|
||||
score /= score.sum(axis=0, keepdim=True)
|
||||
# (horizon, batch, action_dim)
|
||||
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
|
||||
_mean = torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
|
||||
)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
|
||||
** 2,
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
# Update mean with an exponential moving average, and std with a direct replacement.
|
||||
mean = (
|
||||
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
|
||||
self.config.gaussian_mean_momentum * mean
|
||||
+ (1 - self.config.gaussian_mean_momentum) * _mean
|
||||
)
|
||||
std = _std.clamp_(self.config.min_std, self.config.max_std)
|
||||
|
||||
@@ -248,7 +266,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
|
||||
actions = elite_actions[
|
||||
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
|
||||
]
|
||||
|
||||
return actions
|
||||
|
||||
@@ -271,7 +291,8 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# of the FOWM paper.
|
||||
if self.config.uncertainty_regularizer_coeff > 0:
|
||||
regularization = -(
|
||||
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
|
||||
self.config.uncertainty_regularizer_coeff
|
||||
* self.model.Qs(z, actions[t]).std(0)
|
||||
)
|
||||
else:
|
||||
regularization = 0
|
||||
@@ -291,15 +312,22 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if self.config.q_ensemble_size > 2:
|
||||
G += (
|
||||
running_discount
|
||||
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
|
||||
0
|
||||
]
|
||||
* torch.min(
|
||||
terminal_values[
|
||||
torch.randint(0, self.config.q_ensemble_size, size=(2,))
|
||||
],
|
||||
dim=0,
|
||||
)[0]
|
||||
)
|
||||
else:
|
||||
G += running_discount * torch.min(terminal_values, dim=0)[0]
|
||||
# Finally, also regularize the terminal value.
|
||||
if self.config.uncertainty_regularizer_coeff > 0:
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
G -= (
|
||||
running_discount
|
||||
* self.config.uncertainty_regularizer_coeff
|
||||
* terminal_values.std(0)
|
||||
)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
@@ -329,7 +357,10 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# Apply random image augmentations.
|
||||
if self.config.image_features and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
partial(
|
||||
random_shifts_aug,
|
||||
max_random_shift_ratio=self.config.max_random_shift_ratio,
|
||||
),
|
||||
observations["observation.image"],
|
||||
)
|
||||
|
||||
@@ -347,14 +378,20 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds = torch.empty(
|
||||
horizon + 1, batch_size, self.config.latent_dim, device=device
|
||||
)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
for t in range(horizon):
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
|
||||
z_preds[t], action[t]
|
||||
)
|
||||
|
||||
# Compute Q and V value predictions based on the latent rollout.
|
||||
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
|
||||
q_preds_ensemble = self.model.Qs(
|
||||
z_preds[:-1], action
|
||||
) # (ensemble, horizon, batch)
|
||||
v_preds = self.model.V(z_preds[:-1])
|
||||
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
|
||||
|
||||
@@ -368,10 +405,14 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# actions (not actions estimated by π).
|
||||
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
|
||||
# and the FOWM paper.
|
||||
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
|
||||
q_targets = reward + self.config.discount * self.model.V(
|
||||
self.model.encode(next_observations)
|
||||
)
|
||||
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
|
||||
# are using them to compute loss for V.
|
||||
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
|
||||
v_targets = self.model_target.Qs(
|
||||
z_preds[:-1].detach(), action, return_min=True
|
||||
)
|
||||
|
||||
# Compute losses.
|
||||
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
|
||||
@@ -414,7 +455,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
einops.repeat(
|
||||
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
|
||||
),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
@@ -452,12 +495,14 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
z_preds = z_preds.detach()
|
||||
# Use stopgrad for the advantage calculation.
|
||||
with torch.no_grad():
|
||||
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
|
||||
z_preds[:-1]
|
||||
)
|
||||
advantage = self.model_target.Qs(
|
||||
z_preds[:-1], action, return_min=True
|
||||
) - self.model.V(z_preds[:-1])
|
||||
info["advantage"] = advantage[0]
|
||||
# (t, b)
|
||||
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
|
||||
exp_advantage = torch.clamp(
|
||||
torch.exp(advantage * self.config.advantage_scaling), max=100.0
|
||||
)
|
||||
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
|
||||
# Calculate the MSE between the actions and the action predictions.
|
||||
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
|
||||
@@ -511,7 +556,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
|
||||
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
|
||||
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
|
||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||
update_ema_parameters(
|
||||
self.model_target, self.model, self.config.target_model_momentum
|
||||
)
|
||||
|
||||
|
||||
class TDMPCTOLD(nn.Module):
|
||||
@@ -598,7 +645,9 @@ class TDMPCTOLD(nn.Module):
|
||||
"Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
)
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
|
||||
nn.init.zeros_(
|
||||
m[-1].bias
|
||||
) # this has already been done, but keep this line here for good measure
|
||||
|
||||
def encode(self, obs: dict[str, Tensor]) -> Tensor:
|
||||
"""Encodes an observation into its latent representation."""
|
||||
@@ -702,11 +751,26 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
5,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
|
||||
@@ -796,12 +860,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
|
||||
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
|
||||
for (n_p_ema, p_ema), (n_p, p) in zip(
|
||||
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
|
||||
ema_module.named_parameters(recurse=False),
|
||||
module.named_parameters(recurse=False),
|
||||
strict=True,
|
||||
):
|
||||
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
|
||||
if isinstance(p, dict):
|
||||
raise RuntimeError("Dict parameter not supported")
|
||||
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
|
||||
if (
|
||||
isinstance(module, nn.modules.batchnorm._BatchNorm)
|
||||
or not p.requires_grad
|
||||
):
|
||||
# Copy BatchNorm parameters, and non-trainable parameters directly.
|
||||
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
|
||||
with torch.no_grad():
|
||||
@@ -809,7 +878,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
|
||||
|
||||
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
def flatten_forward_unflatten(
|
||||
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
|
||||
) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -145,8 +145,14 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
batch = {
|
||||
k: torch.stack(list(self._queues[k]), dim=1)
|
||||
for k in batch
|
||||
if k in self._queues
|
||||
}
|
||||
actions = self.vqbet(batch, rollout=True)[
|
||||
:, : self.config.action_chunk_size
|
||||
]
|
||||
|
||||
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
@@ -168,7 +174,9 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
|
||||
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
self.vqbet.action_head.discretize(
|
||||
self.config.n_vqvae_training_steps, batch["action"]
|
||||
)
|
||||
)
|
||||
return loss, {
|
||||
"n_different_codes": n_different_codes,
|
||||
@@ -225,7 +233,9 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x, pos_y = np.meshgrid(
|
||||
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
|
||||
)
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
@@ -339,7 +349,12 @@ class VQBeTModel(nn.Module):
|
||||
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
|
||||
self.register_buffer(
|
||||
"select_target_actions_indices",
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
torch.row_stack(
|
||||
[
|
||||
torch.arange(i, i + self.config.action_chunk_size)
|
||||
for i in range(num_tokens)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
|
||||
@@ -354,7 +369,11 @@ class VQBeTModel(nn.Module):
|
||||
)
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
|
||||
img_features,
|
||||
"(b s n) ... -> b s n ...",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
n=self.num_images,
|
||||
)
|
||||
|
||||
# Arrange prior and current observation step tokens as shown in the class docstring.
|
||||
@@ -366,13 +385,19 @@ class VQBeTModel(nn.Module):
|
||||
input_tokens.append(
|
||||
self.state_projector(batch["observation.state"])
|
||||
) # (batch, obs_step, projection dims)
|
||||
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
|
||||
input_tokens.append(
|
||||
einops.repeat(
|
||||
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
|
||||
)
|
||||
)
|
||||
# Interleave tokens by stacking and rearranging.
|
||||
input_tokens = torch.stack(input_tokens, dim=2)
|
||||
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
|
||||
|
||||
len_additional_action_token = self.config.n_action_pred_token - 1
|
||||
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
future_action_tokens = self.action_token.repeat(
|
||||
batch_size, len_additional_action_token, 1
|
||||
)
|
||||
|
||||
# add additional action query tokens for predicting future action chunks
|
||||
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
|
||||
@@ -391,7 +416,11 @@ class VQBeTModel(nn.Module):
|
||||
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
if len_additional_action_token > 0:
|
||||
features = torch.cat(
|
||||
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||
[
|
||||
features[:, historical_act_pred_index],
|
||||
features[:, -len_additional_action_token:],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
features = features[:, historical_act_pred_index]
|
||||
@@ -399,13 +428,15 @@ class VQBeTModel(nn.Module):
|
||||
action_head_output = self.action_head(features)
|
||||
# if rollout, VQ-BeT don't calculate loss
|
||||
if rollout:
|
||||
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
|
||||
batch_size, self.config.action_chunk_size, -1
|
||||
)
|
||||
return action_head_output["predicted_action"][
|
||||
:, n_obs_steps - 1, :
|
||||
].reshape(batch_size, self.config.action_chunk_size, -1)
|
||||
# else, it calculate overall loss (bin prediction loss, and offset loss)
|
||||
else:
|
||||
output = batch["action"][:, self.select_target_actions_indices]
|
||||
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
|
||||
loss = self.action_head.loss_fn(
|
||||
action_head_output, output, reduction="mean"
|
||||
)
|
||||
return action_head_output, loss
|
||||
|
||||
|
||||
@@ -440,7 +471,9 @@ class VQBeTHead(nn.Module):
|
||||
else:
|
||||
self.map_to_cbet_preds_bin = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
|
||||
hidden_channels=[
|
||||
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
|
||||
],
|
||||
)
|
||||
self.map_to_cbet_preds_offset = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
@@ -467,7 +500,10 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
loss, metric = self.vqvae_model.vqvae_forward(actions)
|
||||
n_different_codes = sum(
|
||||
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
|
||||
[
|
||||
len(torch.unique(metric[2][:, i]))
|
||||
for i in range(self.vqvae_model.vqvae_num_layers)
|
||||
]
|
||||
)
|
||||
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
||||
recon_l1_error = metric[0].detach().cpu().item()
|
||||
@@ -514,7 +550,13 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
|
||||
torch.cat(
|
||||
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
|
||||
(
|
||||
x,
|
||||
F.one_hot(
|
||||
sampled_primary_centers,
|
||||
num_classes=self.config.vqvae_n_embed,
|
||||
),
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
)
|
||||
@@ -522,19 +564,29 @@ class VQBeTHead(nn.Module):
|
||||
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
sampled_secondary_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
|
||||
torch.multinomial(
|
||||
cbet_secondary_probs.view(-1, choices), num_samples=1
|
||||
),
|
||||
"(NT) 1 -> NT",
|
||||
NT=NT,
|
||||
)
|
||||
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
|
||||
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
|
||||
sampled_centers = torch.stack(
|
||||
(sampled_primary_centers, sampled_secondary_centers), axis=1
|
||||
)
|
||||
cbet_logits = torch.stack(
|
||||
[cbet_primary_logits, cbet_secondary_logits], dim=1
|
||||
)
|
||||
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
|
||||
else:
|
||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||
cbet_logits = einops.rearrange(
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||
cbet_logits,
|
||||
"(NT) (G C) -> (NT) G C",
|
||||
G=self.vqvae_model.vqvae_num_layers,
|
||||
)
|
||||
cbet_probs = torch.softmax(
|
||||
cbet_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
sampled_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||
@@ -554,9 +606,17 @@ class VQBeTHead(nn.Module):
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
with torch.no_grad():
|
||||
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
|
||||
return_decoder_input = (
|
||||
self.vqvae_model.get_embeddings_from_code(sampled_centers)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
|
||||
decoded_action = (
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
# reshaped extracted offset to match with decoded centroids
|
||||
sampled_offsets = einops.rearrange(
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
|
||||
@@ -605,7 +665,9 @@ class VQBeTHead(nn.Module):
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each ground truth action.
|
||||
with torch.no_grad():
|
||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
state_vq, action_bins = self.vqvae_model.get_code(
|
||||
action_seq
|
||||
) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
|
||||
@@ -628,8 +690,12 @@ class VQBeTHead(nn.Module):
|
||||
+ cbet_loss2 * self.config.secondary_code_loss_weight
|
||||
)
|
||||
|
||||
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
|
||||
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
|
||||
equal_primary_code_rate = torch.sum(
|
||||
(action_bins[:, 0] == sampled_centers[:, 0]).int()
|
||||
) / (NT)
|
||||
equal_secondary_code_rate = torch.sum(
|
||||
(action_bins[:, 1] == sampled_centers[:, 1]).int()
|
||||
) / (NT)
|
||||
|
||||
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
|
||||
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
|
||||
@@ -643,7 +709,9 @@ class VQBeTHead(nn.Module):
|
||||
"classification_loss": cbet_loss.detach().cpu().item(),
|
||||
"offset_loss": offset_loss.detach().cpu().item(),
|
||||
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
|
||||
.cpu()
|
||||
.item(),
|
||||
"vq_action_error": vq_action_error.detach().cpu().item(),
|
||||
"offset_action_error": offset_action_error.detach().cpu().item(),
|
||||
"action_error_max": action_error_max.detach().cpu().item(),
|
||||
@@ -668,7 +736,9 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(
|
||||
config.crop_shape
|
||||
)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -689,7 +759,9 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
@@ -730,7 +802,9 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
@@ -743,7 +817,11 @@ def _replace_submodules(
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
replace_list = [
|
||||
k.split(".")
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
@@ -758,7 +836,9 @@ def _replace_submodules(
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
assert not any(
|
||||
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
|
||||
)
|
||||
return root_module
|
||||
|
||||
|
||||
|
||||
@@ -123,9 +123,15 @@ class CausalSelfAttention(nn.Module):
|
||||
|
||||
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
||||
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
|
||||
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
|
||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
@@ -133,7 +139,9 @@ class CausalSelfAttention(nn.Module):
|
||||
att = F.softmax(att, dim=-1)
|
||||
att = self.attn_dropout(att)
|
||||
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
||||
y = (
|
||||
y.transpose(1, 2).contiguous().view(B, T, C)
|
||||
) # re-assemble all head outputs side by side
|
||||
|
||||
# output projection
|
||||
y = self.resid_dropout(self.c_proj(y))
|
||||
@@ -189,12 +197,16 @@ class GPT(nn.Module):
|
||||
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
|
||||
}
|
||||
)
|
||||
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
|
||||
self.lm_head = nn.Linear(
|
||||
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
|
||||
)
|
||||
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
|
||||
self.apply(self._init_weights)
|
||||
for pn, p in self.named_parameters():
|
||||
if pn.endswith("c_proj.weight"):
|
||||
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
|
||||
torch.nn.init.normal_(
|
||||
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
|
||||
)
|
||||
|
||||
# report number of parameters
|
||||
n_params = sum(p.numel() for p in self.parameters())
|
||||
@@ -208,11 +220,17 @@ class GPT(nn.Module):
|
||||
)
|
||||
|
||||
# positional encodings that are added to the input embeddings
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
|
||||
0
|
||||
) # shape (1, t)
|
||||
|
||||
# forward the GPT model itself
|
||||
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||
tok_emb = self.transformer.wte(
|
||||
input
|
||||
) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||
pos_emb = self.transformer.wpe(
|
||||
pos
|
||||
) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||
x = self.transformer.drop(tok_emb + pos_emb)
|
||||
for block in self.transformer.h:
|
||||
x = block(x)
|
||||
@@ -237,7 +255,9 @@ class GPT(nn.Module):
|
||||
# but want to use a smaller block size for some smaller, simpler model
|
||||
assert gpt_block_size <= self.config.gpt_block_size
|
||||
self.config.gpt_block_size = gpt_block_size
|
||||
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
|
||||
self.transformer.wpe.weight = nn.Parameter(
|
||||
self.transformer.wpe.weight[:gpt_block_size]
|
||||
)
|
||||
for block in self.transformer.h:
|
||||
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
|
||||
|
||||
@@ -270,7 +290,9 @@ class GPT(nn.Module):
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
assert (
|
||||
len(inter_params) == 0
|
||||
), "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
@@ -368,8 +390,12 @@ class ResidualVQ(nn.Module):
|
||||
codebook_input_dim = codebook_dim * heads
|
||||
|
||||
requires_projection = codebook_input_dim != dim
|
||||
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
self.project_in = (
|
||||
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
|
||||
self.num_quantizers = num_quantizers
|
||||
|
||||
@@ -377,7 +403,10 @@ class ResidualVQ(nn.Module):
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
VectorQuantize(
|
||||
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
|
||||
dim=codebook_dim,
|
||||
codebook_dim=codebook_dim,
|
||||
accept_image_fmap=accept_image_fmap,
|
||||
**kwargs,
|
||||
)
|
||||
for _ in range(num_quantizers)
|
||||
]
|
||||
@@ -448,7 +477,9 @@ class ResidualVQ(nn.Module):
|
||||
|
||||
return all_codes
|
||||
|
||||
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
|
||||
def forward(
|
||||
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
|
||||
):
|
||||
"""
|
||||
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
|
||||
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
|
||||
@@ -477,13 +508,17 @@ class ResidualVQ(nn.Module):
|
||||
)
|
||||
ce_losses = []
|
||||
|
||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||
should_quantize_dropout = (
|
||||
self.training and self.quantize_dropout and not return_loss
|
||||
)
|
||||
|
||||
# sample a layer index at which to dropout further residual quantization
|
||||
# also prepare null indices and loss
|
||||
|
||||
if should_quantize_dropout:
|
||||
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
|
||||
rand_quantize_dropout_index = randrange(
|
||||
self.quantize_dropout_cutoff_index, num_quant
|
||||
)
|
||||
|
||||
if quant_dropout_multiple_of != 1:
|
||||
rand_quantize_dropout_index = (
|
||||
@@ -492,14 +527,23 @@ class ResidualVQ(nn.Module):
|
||||
- 1
|
||||
)
|
||||
|
||||
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
|
||||
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
|
||||
null_indices_shape = (
|
||||
(x.shape[0], *x.shape[-2:])
|
||||
if self.accept_image_fmap
|
||||
else tuple(x.shape[:2])
|
||||
)
|
||||
null_indices = torch.full(
|
||||
null_indices_shape, -1.0, device=device, dtype=torch.long
|
||||
)
|
||||
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
|
||||
|
||||
# go through the layers
|
||||
|
||||
for quantizer_index, layer in enumerate(self.layers):
|
||||
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
|
||||
if (
|
||||
should_quantize_dropout
|
||||
and quantizer_index > rand_quantize_dropout_index
|
||||
):
|
||||
all_indices.append(null_indices)
|
||||
all_losses.append(null_loss)
|
||||
continue
|
||||
@@ -539,7 +583,9 @@ class ResidualVQ(nn.Module):
|
||||
|
||||
# stack all losses and indices
|
||||
|
||||
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
|
||||
all_losses, all_indices = map(
|
||||
partial(torch.stack, dim=-1), (all_losses, all_indices)
|
||||
)
|
||||
|
||||
ret = (quantized_out, all_indices, all_losses)
|
||||
|
||||
@@ -599,8 +645,12 @@ class VectorQuantize(nn.Module):
|
||||
codebook_input_dim = codebook_dim * heads
|
||||
|
||||
requires_projection = codebook_input_dim != dim
|
||||
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
self.project_in = (
|
||||
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
|
||||
self.eps = eps
|
||||
self.commitment_weight = commitment_weight
|
||||
@@ -614,10 +664,14 @@ class VectorQuantize(nn.Module):
|
||||
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
||||
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
||||
|
||||
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
|
||||
assert not (
|
||||
ema_update and learnable_codebook
|
||||
), "learnable codebook not compatible with EMA update"
|
||||
|
||||
assert 0 <= sync_update_v <= 1.0
|
||||
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
|
||||
assert not (
|
||||
sync_update_v > 0.0 and not learnable_codebook
|
||||
), "learnable codebook must be turned on"
|
||||
|
||||
self.sync_update_v = sync_update_v
|
||||
|
||||
@@ -629,7 +683,9 @@ class VectorQuantize(nn.Module):
|
||||
)
|
||||
|
||||
if sync_codebook is None:
|
||||
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
|
||||
sync_codebook = (
|
||||
distributed.is_initialized() and distributed.get_world_size() > 1
|
||||
)
|
||||
|
||||
codebook_kwargs = {
|
||||
"dim": codebook_dim,
|
||||
@@ -794,11 +850,17 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
# quantize again
|
||||
|
||||
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
|
||||
quantize, embed_ind, distances = self._codebook(
|
||||
x, **codebook_forward_kwargs
|
||||
)
|
||||
|
||||
if self.training:
|
||||
# determine code to use for commitment loss
|
||||
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
|
||||
maybe_detach = (
|
||||
torch.detach
|
||||
if not self.learnable_codebook or freeze_codebook
|
||||
else identity
|
||||
)
|
||||
|
||||
commit_quantize = maybe_detach(quantize)
|
||||
|
||||
@@ -808,7 +870,9 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
if self.sync_update_v > 0.0:
|
||||
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
|
||||
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
|
||||
quantize = quantize + self.sync_update_v * (
|
||||
quantize - quantize.detach()
|
||||
)
|
||||
|
||||
# function for calculating cross entropy loss to distance matrix
|
||||
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
|
||||
@@ -841,7 +905,9 @@ class VectorQuantize(nn.Module):
|
||||
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
|
||||
|
||||
if self.accept_image_fmap:
|
||||
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
|
||||
embed_ind = rearrange(
|
||||
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
|
||||
)
|
||||
|
||||
if only_one:
|
||||
embed_ind = rearrange(embed_ind, "b 1 -> b")
|
||||
@@ -895,8 +961,12 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
num_codes = codebook.shape[-2]
|
||||
|
||||
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
|
||||
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
|
||||
if (
|
||||
self.orthogonal_reg_max_codes is not None
|
||||
) and num_codes > self.orthogonal_reg_max_codes:
|
||||
rand_ids = torch.randperm(num_codes, device=device)[
|
||||
: self.orthogonal_reg_max_codes
|
||||
]
|
||||
codebook = codebook[:, rand_ids]
|
||||
|
||||
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
|
||||
@@ -928,7 +998,9 @@ class VectorQuantize(nn.Module):
|
||||
# if masking, only return quantized for where mask has True
|
||||
|
||||
if mask is not None:
|
||||
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
|
||||
quantize = torch.where(
|
||||
rearrange(mask, "... -> ... 1"), quantize, orig_input
|
||||
)
|
||||
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
@@ -1038,7 +1110,9 @@ def sample_vectors(samples, num):
|
||||
|
||||
|
||||
def batched_sample_vectors(samples, num):
|
||||
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
|
||||
return torch.stack(
|
||||
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
|
||||
)
|
||||
|
||||
|
||||
def pad_shape(shape, size, dim=0):
|
||||
@@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num):
|
||||
all_num_samples = all_gather_sizes(local_samples, dim=0)
|
||||
|
||||
if rank == 0:
|
||||
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
|
||||
samples_per_rank = sample_multinomial(
|
||||
num, all_num_samples / all_num_samples.sum()
|
||||
)
|
||||
else:
|
||||
samples_per_rank = torch.empty_like(all_num_samples)
|
||||
|
||||
@@ -1202,7 +1278,9 @@ class EuclideanCodebook(nn.Module):
|
||||
self.eps = eps
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
self.reset_cluster_size = (
|
||||
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
|
||||
reset_cluster_size
|
||||
if (reset_cluster_size is not None)
|
||||
else threshold_ema_dead_code
|
||||
)
|
||||
|
||||
assert callable(gumbel_sample)
|
||||
@@ -1213,8 +1291,14 @@ class EuclideanCodebook(nn.Module):
|
||||
"kmeans init is not compatible with multiple codebooks in distributed environment for now"
|
||||
)
|
||||
|
||||
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
|
||||
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
|
||||
self.sample_fn = (
|
||||
sample_vectors_distributed
|
||||
if use_ddp and sync_kmeans
|
||||
else batched_sample_vectors
|
||||
)
|
||||
self.kmeans_all_reduce_fn = (
|
||||
distributed.all_reduce if use_ddp and sync_kmeans else noop
|
||||
)
|
||||
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
|
||||
|
||||
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
|
||||
@@ -1353,7 +1437,9 @@ class EuclideanCodebook(nn.Module):
|
||||
distributed.all_reduce(variance_number)
|
||||
batch_variance = variance_number / num_vectors
|
||||
|
||||
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
|
||||
self.update_with_decay(
|
||||
"batch_variance", batch_variance, self.affine_param_batch_decay
|
||||
)
|
||||
|
||||
def replace(self, batch_samples, batch_mask):
|
||||
for ind, (samples, mask) in enumerate(
|
||||
@@ -1362,7 +1448,9 @@ class EuclideanCodebook(nn.Module):
|
||||
if not torch.any(mask):
|
||||
continue
|
||||
|
||||
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
|
||||
sampled = self.sample_fn(
|
||||
rearrange(samples, "... -> 1 ..."), mask.sum().item()
|
||||
)
|
||||
sampled = rearrange(sampled, "1 ... -> ...")
|
||||
|
||||
self.embed.data[ind][mask] = sampled
|
||||
@@ -1386,7 +1474,9 @@ class EuclideanCodebook(nn.Module):
|
||||
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
|
||||
needs_codebook_dim = x.ndim < 4
|
||||
sample_codebook_temp = (
|
||||
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
|
||||
sample_codebook_temp
|
||||
if (sample_codebook_temp is not None)
|
||||
else self.sample_codebook_temp
|
||||
)
|
||||
|
||||
x = x.float()
|
||||
@@ -1414,7 +1504,9 @@ class EuclideanCodebook(nn.Module):
|
||||
if self.affine_param:
|
||||
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
|
||||
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
|
||||
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
|
||||
embed = (embed - self.codebook_mean) * (
|
||||
batch_std / codebook_std
|
||||
) + self.batch_mean
|
||||
|
||||
dist = -cdist(flatten, embed)
|
||||
|
||||
@@ -1432,7 +1524,9 @@ class EuclideanCodebook(nn.Module):
|
||||
|
||||
if self.training and self.ema_update and not freeze_codebook:
|
||||
if self.affine_param:
|
||||
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
|
||||
flatten = (flatten - self.batch_mean) * (
|
||||
codebook_std / batch_std
|
||||
) + self.codebook_mean
|
||||
|
||||
if mask is not None:
|
||||
embed_onehot[~mask] = 0.0
|
||||
@@ -1455,7 +1549,9 @@ class EuclideanCodebook(nn.Module):
|
||||
self.expire_codes_(x)
|
||||
|
||||
if needs_codebook_dim:
|
||||
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
|
||||
quantize, embed_ind = tuple(
|
||||
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
|
||||
)
|
||||
|
||||
dist = unpack_one(dist, ps, "h * d")
|
||||
|
||||
|
||||
@@ -79,7 +79,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
|
||||
img.save(str(path), quality=100)
|
||||
logging.info(f"Saved image: {path}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
|
||||
logging.error(
|
||||
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def save_images_from_cameras(
|
||||
@@ -157,7 +159,9 @@ def save_images_from_cameras(
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
print(
|
||||
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
|
||||
)
|
||||
|
||||
frame_index += 1
|
||||
finally:
|
||||
@@ -275,7 +279,9 @@ class IntelRealSenseCamera:
|
||||
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
|
||||
)
|
||||
|
||||
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
|
||||
name_to_serial_dict = {
|
||||
cam["name"]: cam["serial_number"] for cam in camera_infos
|
||||
}
|
||||
cam_sn = name_to_serial_dict[name]
|
||||
|
||||
return cam_sn
|
||||
@@ -339,7 +345,9 @@ class IntelRealSenseCamera:
|
||||
actual_height = color_profile.height()
|
||||
|
||||
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
if self.fps is not None and not math.isclose(
|
||||
self.fps, actual_fps, rel_tol=1e-3
|
||||
):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
@@ -359,7 +367,9 @@ class IntelRealSenseCamera:
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
def read(
|
||||
self, temporary_color: str | None = None
|
||||
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
|
||||
of type `np.uint8`, contrarily to the pytorch format which is float channel first.
|
||||
|
||||
@@ -386,11 +396,15 @@ class IntelRealSenseCamera:
|
||||
color_frame = frame.get_color_frame()
|
||||
|
||||
if not color_frame:
|
||||
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
|
||||
raise OSError(
|
||||
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
|
||||
)
|
||||
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
|
||||
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
|
||||
requested_color_mode = (
|
||||
self.color_mode if temporary_color is None else temporary_color
|
||||
)
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
|
||||
@@ -418,7 +432,9 @@ class IntelRealSenseCamera:
|
||||
if self.use_depth:
|
||||
depth_frame = frame.get_depth_frame()
|
||||
if not depth_frame:
|
||||
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
|
||||
raise OSError(
|
||||
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
|
||||
)
|
||||
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
@@ -460,7 +476,9 @@ class IntelRealSenseCamera:
|
||||
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
|
||||
num_tries += 1
|
||||
time.sleep(1 / self.fps)
|
||||
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
||||
if num_tries > self.fps and (
|
||||
self.thread.ident is None or not self.thread.is_alive()
|
||||
):
|
||||
raise Exception(
|
||||
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
||||
)
|
||||
|
||||
@@ -45,10 +45,14 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
MAX_OPENCV_INDEX = 60
|
||||
|
||||
|
||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||
def find_cameras(
|
||||
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
|
||||
) -> list[dict]:
|
||||
cameras = []
|
||||
if platform.system() == "Linux":
|
||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||
print(
|
||||
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
|
||||
)
|
||||
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
|
||||
ports = _find_cameras(possible_ports, mock=mock)
|
||||
for port in ports:
|
||||
@@ -180,7 +184,9 @@ def save_images_from_cameras(
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
print(
|
||||
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
|
||||
)
|
||||
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
@@ -237,7 +243,9 @@ class OpenCVCamera:
|
||||
if platform.system() == "Linux":
|
||||
if isinstance(self.camera_index, int):
|
||||
self.port = Path(f"/dev/video{self.camera_index}")
|
||||
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
|
||||
elif isinstance(self.camera_index, str) and is_valid_unix_path(
|
||||
self.camera_index
|
||||
):
|
||||
self.port = Path(self.camera_index)
|
||||
# Retrieve the camera index from a potentially symlinked path
|
||||
self.camera_index = get_camera_index_from_unix_port(self.port)
|
||||
@@ -283,7 +291,9 @@ class OpenCVCamera:
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is already connected."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
@@ -344,7 +354,9 @@ class OpenCVCamera:
|
||||
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
|
||||
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
if self.fps is not None and not math.isclose(
|
||||
self.fps, actual_fps, rel_tol=1e-3
|
||||
):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
||||
@@ -386,7 +398,9 @@ class OpenCVCamera:
|
||||
if not ret:
|
||||
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
||||
|
||||
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
requested_color_mode = (
|
||||
self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
)
|
||||
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -39,7 +39,9 @@ from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||
|
||||
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
def log_control_info(
|
||||
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||
):
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items.append(f"ep:{episode_index}")
|
||||
@@ -106,7 +108,9 @@ def predict_action(observation, policy, device, use_amp):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
@@ -162,7 +166,9 @@ def init_keyboard_listener(assign_rewards=False):
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
print(
|
||||
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
|
||||
)
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
@@ -256,7 +262,9 @@ def control_loop(
|
||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
raise ValueError(
|
||||
f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
|
||||
)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
@@ -291,7 +299,9 @@ def control_loop(
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.imshow(
|
||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||
)
|
||||
cv2.waitKey(1)
|
||||
|
||||
if fps is not None:
|
||||
@@ -361,7 +371,11 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
|
||||
dataset: LeRobotDataset,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
use_videos: bool,
|
||||
extra_features: dict = None,
|
||||
) -> None:
|
||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||
if extra_features is not None:
|
||||
@@ -375,11 +389,14 @@ def sanity_check_dataset_robot_compatibility(
|
||||
|
||||
mismatches = []
|
||||
for field, dataset_value, present_value in fields:
|
||||
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
|
||||
diff = DeepDiff(
|
||||
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
|
||||
)
|
||||
if diff:
|
||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||
|
||||
if mismatches:
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
"Dataset metadata compatibility check failed with mismatches:\n"
|
||||
+ "\n".join(mismatches)
|
||||
)
|
||||
|
||||
@@ -158,7 +158,9 @@ NUM_READ_RETRY = 10
|
||||
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
def convert_degrees_to_steps(
|
||||
degrees: float | np.ndarray, models: str | list[str]
|
||||
) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
@@ -384,7 +386,9 @@ class DynamixelMotorsBus:
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
present_idx = self.read_with_motor_ids(
|
||||
self.motor_models, [idx], "ID", num_retry=num_retry
|
||||
)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
@@ -400,7 +404,9 @@ class DynamixelMotorsBus:
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
print(
|
||||
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
|
||||
)
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
@@ -421,7 +427,9 @@ class DynamixelMotorsBus:
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def apply_calibration_autocorrect(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
@@ -434,7 +442,9 @@ class DynamixelMotorsBus:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def apply_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
@@ -509,7 +519,9 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def autocorrect_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
@@ -551,15 +563,23 @@ class DynamixelMotorsBus:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
calib_val = (
|
||||
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
|
||||
calib_val < UPPER_BOUND_DEGREE
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
|
||||
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-(resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
(resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
@@ -567,7 +587,9 @@ class DynamixelMotorsBus:
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
|
||||
calib_val < UPPER_BOUND_LINEAR
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
@@ -583,19 +605,27 @@ class DynamixelMotorsBus:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
out_of_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
in_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
@@ -605,7 +635,9 @@ class DynamixelMotorsBus:
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def revert_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
@@ -644,7 +676,9 @@ class DynamixelMotorsBus:
|
||||
values = np.round(values).astype(np.int32)
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
def read_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
|
||||
):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
@@ -746,7 +780,9 @@ class DynamixelMotorsBus:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "read", data_name, motor_names
|
||||
)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
@@ -755,7 +791,9 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
def write_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
|
||||
):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
@@ -784,7 +822,12 @@ class DynamixelMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
@@ -845,7 +888,9 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "write", data_name, motor_names
|
||||
)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
|
||||
@@ -137,7 +137,9 @@ NUM_READ_RETRY = 20
|
||||
NUM_WRITE_RETRY = 20
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
def convert_degrees_to_steps(
|
||||
degrees: float | np.ndarray, models: str | list[str]
|
||||
) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
@@ -365,7 +367,9 @@ class FeetechMotorsBus:
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
present_idx = self.read_with_motor_ids(
|
||||
self.motor_models, [idx], "ID", num_retry=num_retry
|
||||
)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
@@ -381,7 +385,9 @@ class FeetechMotorsBus:
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
print(
|
||||
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
|
||||
)
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
@@ -402,7 +408,9 @@ class FeetechMotorsBus:
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def apply_calibration_autocorrect(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
@@ -415,7 +423,9 @@ class FeetechMotorsBus:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def apply_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
@@ -489,7 +499,9 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def autocorrect_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
@@ -528,18 +540,26 @@ class FeetechMotorsBus:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
calib_val = (
|
||||
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
|
||||
calib_val < UPPER_BOUND_DEGREE
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
|
||||
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
|
||||
- values[i]
|
||||
- homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
|
||||
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
|
||||
- values[i]
|
||||
- homing_offset
|
||||
) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
@@ -548,7 +568,9 @@ class FeetechMotorsBus:
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
|
||||
calib_val < UPPER_BOUND_LINEAR
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
@@ -564,19 +586,27 @@ class FeetechMotorsBus:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
out_of_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
in_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
@@ -586,7 +616,9 @@ class FeetechMotorsBus:
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
def revert_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
@@ -662,7 +694,9 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
def read_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
|
||||
):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
@@ -771,7 +805,9 @@ class FeetechMotorsBus:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "read", data_name, motor_names
|
||||
)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
@@ -780,7 +816,9 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
def write_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
|
||||
):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
@@ -809,7 +847,12 @@ class FeetechMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
@@ -870,7 +913,9 @@ class FeetechMotorsBus:
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "write", data_name, motor_names
|
||||
)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
|
||||
@@ -24,9 +24,7 @@ from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
|
||||
# The following positions are provided in nominal degree range ]-180, +180[
|
||||
# For more info on these constants, see comments in the code where they get used.
|
||||
@@ -37,7 +35,9 @@ ROTATED_POSITION_DEGREE = 90
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
raise ValueError(
|
||||
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
|
||||
)
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
@@ -78,12 +78,16 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
@@ -104,10 +108,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
rotated_target_pos = convert_degrees_to_steps(
|
||||
ROTATED_POSITION_DEGREE, arm.motor_models
|
||||
)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -116,11 +125,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
|
||||
# Re-compute homing offset to take into account drive mode
|
||||
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(
|
||||
rotated_drived_pos, arm.motor_models
|
||||
)
|
||||
homing_offset = rotated_target_pos - rotated_nearest_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
|
||||
@@ -26,9 +26,7 @@ from lerobot.common.robot_devices.motors.feetech import (
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
|
||||
# The following positions are provided in nominal degree range ]-180, +180[
|
||||
# For more info on these constants, see comments in the code where they get used.
|
||||
@@ -39,7 +37,9 @@ ROTATED_POSITION_DEGREE = 90
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
raise ValueError(
|
||||
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
|
||||
)
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
@@ -140,7 +140,9 @@ def apply_offset(calib, offset):
|
||||
return calib
|
||||
|
||||
|
||||
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
def run_arm_auto_calibration(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
if robot_type == "so100":
|
||||
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
|
||||
elif robot_type == "moss":
|
||||
@@ -149,18 +151,27 @@ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm
|
||||
raise ValueError(robot_type)
|
||||
|
||||
|
||||
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
def run_arm_auto_calibration_so100(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
if not (robot_type == "so100" and arm_type == "follower"):
|
||||
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
|
||||
raise NotImplementedError(
|
||||
"Auto calibration only supports the follower of so100 arms for now."
|
||||
)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
@@ -207,11 +218,16 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
|
||||
print("Calibrate elbow_flex")
|
||||
calib["elbow_flex"] = move_to_calibrate(
|
||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
|
||||
arm,
|
||||
"elbow_flex",
|
||||
positive_first=False,
|
||||
in_between_move_hook=in_between_move_hook,
|
||||
)
|
||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
||||
arm.write(
|
||||
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
def in_between_move_hook():
|
||||
@@ -239,18 +255,30 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
}
|
||||
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
|
||||
|
||||
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
|
||||
arm.write(
|
||||
"Goal_Position",
|
||||
round(calib["shoulder_lift"]["zero_pos"] - 1600),
|
||||
"shoulder_lift",
|
||||
)
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
|
||||
arm.write(
|
||||
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
|
||||
)
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
|
||||
arm.write(
|
||||
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
|
||||
)
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
|
||||
time.sleep(2)
|
||||
|
||||
print("Calibrate wrist_roll")
|
||||
calib["wrist_roll"] = move_to_calibrate(
|
||||
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
|
||||
arm,
|
||||
"wrist_roll",
|
||||
invert_drive_mode=True,
|
||||
positive_first=False,
|
||||
while_move_hook=while_move_hook,
|
||||
)
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
|
||||
@@ -260,7 +288,9 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
|
||||
arm.write(
|
||||
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
|
||||
)
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
|
||||
time.sleep(1)
|
||||
@@ -289,18 +319,27 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
def run_arm_auto_calibration_moss(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
if not (robot_type == "moss" and arm_type == "follower"):
|
||||
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
|
||||
raise NotImplementedError(
|
||||
"Auto calibration only supports the follower of moss arms for now."
|
||||
)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
@@ -384,8 +423,12 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
|
||||
arm.write(
|
||||
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
|
||||
)
|
||||
arm.write(
|
||||
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
|
||||
)
|
||||
time.sleep(2)
|
||||
|
||||
calib_modes = []
|
||||
@@ -412,7 +455,9 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
def run_arm_manual_calibration(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
"""This function ensures that a neural network trained on data collected on a given robot
|
||||
can work on another robot. For instance before calibration, setting a same goal position
|
||||
for each motor of two different robots will get two very different positions. But after calibration,
|
||||
@@ -435,12 +480,16 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
@@ -460,10 +509,15 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
rotated_target_pos = convert_degrees_to_steps(
|
||||
ROTATED_POSITION_DEGREE, arm.motor_models
|
||||
)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -475,7 +529,9 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||
homing_offset = rotated_target_pos - rotated_drived_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
|
||||
)
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
|
||||
@@ -31,11 +31,16 @@ from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
|
||||
|
||||
def ensure_safe_goal_position(
|
||||
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
|
||||
goal_pos: torch.Tensor,
|
||||
present_pos: torch.Tensor,
|
||||
max_relative_target: float | list[float],
|
||||
):
|
||||
# Cap relative action target magnitude for safety.
|
||||
diff = goal_pos - present_pos
|
||||
@@ -277,7 +282,9 @@ class ManipulatorRobot:
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
|
||||
self.leader_arms[name].write(
|
||||
"Goal_Position", self.config.gripper_open_degree, "gripper"
|
||||
)
|
||||
|
||||
# Check both arms can be read
|
||||
for name in self.follower_arms:
|
||||
@@ -309,18 +316,26 @@ class ManipulatorRobot:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
|
||||
run_arm_calibration,
|
||||
)
|
||||
|
||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
calibration = run_arm_calibration(
|
||||
arm, self.robot_type, name, arm_type
|
||||
)
|
||||
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
calibration = run_arm_manual_calibration(
|
||||
arm, self.robot_type, name, arm_type
|
||||
)
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
print(
|
||||
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
|
||||
)
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
@@ -339,13 +354,17 @@ class ManipulatorRobot:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
|
||||
raise ValueError(
|
||||
"To run set robot preset, the torque must be disabled on all motors."
|
||||
)
|
||||
|
||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
|
||||
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
|
||||
all_motors_except_gripper = [
|
||||
name for name in arm.motor_names if name != "gripper"
|
||||
]
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
# 4 corresponds to Extended Position on Koch motors
|
||||
arm.write("Operating_Mode", 4, all_motors_except_gripper)
|
||||
@@ -374,7 +393,9 @@ class ManipulatorRobot:
|
||||
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
|
||||
# so that we can use it as a trigger to close the gripper of the follower arms.
|
||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
|
||||
self.leader_arms[name].write(
|
||||
"Goal_Position", self.config.gripper_open_degree, "gripper"
|
||||
)
|
||||
|
||||
def set_aloha_robot_preset(self):
|
||||
def set_shadow_(arm):
|
||||
@@ -404,11 +425,15 @@ class ManipulatorRobot:
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [
|
||||
name for name in self.follower_arms[name].motor_names if name != "gripper"
|
||||
name
|
||||
for name in self.follower_arms[name].motor_names
|
||||
if name != "gripper"
|
||||
]
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
# 4 corresponds to Extended Position on Aloha motors
|
||||
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
|
||||
self.follower_arms[name].write(
|
||||
"Operating_Mode", 4, all_motors_except_gripper
|
||||
)
|
||||
|
||||
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
|
||||
# It can grasp an object without forcing too much even tho,
|
||||
@@ -456,7 +481,9 @@ class ManipulatorRobot:
|
||||
before_lread_t = time.perf_counter()
|
||||
leader_pos[name] = self.leader_arms[name].read("Present_Position")
|
||||
leader_pos[name] = torch.from_numpy(leader_pos[name])
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_lread_t
|
||||
)
|
||||
|
||||
# Send goal position to the follower
|
||||
follower_goal_pos = {}
|
||||
@@ -477,14 +504,18 @@ class ManipulatorRobot:
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.follower_arms[name].read("Present_Position")
|
||||
present_pos = torch.from_numpy(present_pos)
|
||||
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
|
||||
goal_pos = ensure_safe_goal_position(
|
||||
goal_pos, present_pos, self.config.max_relative_target
|
||||
)
|
||||
|
||||
# Used when record_data=True
|
||||
follower_goal_pos[name] = goal_pos
|
||||
|
||||
goal_pos = goal_pos.numpy().astype(np.float32)
|
||||
self.follower_arms[name].write("Goal_Position", goal_pos)
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fwrite_t
|
||||
)
|
||||
|
||||
# Early exit when recording data is not requested
|
||||
if not record_data:
|
||||
@@ -497,7 +528,9 @@ class ManipulatorRobot:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
follower_pos[name] = torch.from_numpy(follower_pos[name])
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fread_t
|
||||
)
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -519,8 +552,12 @@ class ManipulatorRobot:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
|
||||
# Populate output dictionaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
@@ -544,7 +581,9 @@ class ManipulatorRobot:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
follower_pos[name] = torch.from_numpy(follower_pos[name])
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fread_t
|
||||
)
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -559,8 +598,12 @@ class ManipulatorRobot:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
|
||||
# Populate output dictionaries and format to pytorch
|
||||
obs_dict = {}
|
||||
@@ -606,7 +649,9 @@ class ManipulatorRobot:
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.follower_arms[name].read("Present_Position")
|
||||
present_pos = torch.from_numpy(present_pos)
|
||||
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
|
||||
goal_pos = ensure_safe_goal_position(
|
||||
goal_pos, present_pos, self.config.max_relative_target
|
||||
)
|
||||
|
||||
# Save tensor to concat and return
|
||||
action_sent.append(goal_pos)
|
||||
|
||||
@@ -52,7 +52,9 @@ class StretchRobot(StretchAPI):
|
||||
def connect(self) -> None:
|
||||
self.is_connected = self.startup()
|
||||
if not self.is_connected:
|
||||
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
|
||||
print(
|
||||
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
|
||||
)
|
||||
raise ConnectionError()
|
||||
|
||||
for name in self.cameras:
|
||||
@@ -60,7 +62,9 @@ class StretchRobot(StretchAPI):
|
||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
print(
|
||||
"Could not connect to the cameras, check that all cameras are plugged-in."
|
||||
)
|
||||
raise ConnectionError()
|
||||
|
||||
self.run_calibration()
|
||||
@@ -105,8 +109,12 @@ class StretchRobot(StretchAPI):
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
|
||||
# Populate output dictionaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
@@ -150,8 +158,12 @@ class StretchRobot(StretchAPI):
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
|
||||
# Populate output dictionaries
|
||||
obs_dict = {}
|
||||
|
||||
@@ -48,7 +48,8 @@ class RobotDeviceNotConnectedError(Exception):
|
||||
"""Exception raised when the robot device is not connected."""
|
||||
|
||||
def __init__(
|
||||
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
|
||||
self,
|
||||
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
@@ -17,7 +17,9 @@ import importlib
|
||||
import logging
|
||||
|
||||
|
||||
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
||||
def is_package_available(
|
||||
pkg_name: str, return_version: bool = False
|
||||
) -> tuple[bool, str] | bool:
|
||||
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
|
||||
Check if the package spec exists and grab its version to avoid importing a local directory.
|
||||
**Note:** this doesn't work for all packages.
|
||||
|
||||
@@ -28,7 +28,9 @@ def write_video(video_path, stacked_frames, fps):
|
||||
# Filter out DeprecationWarnings raised from pkg_resources
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
|
||||
"ignore",
|
||||
"pkg_resources is deprecated as an API",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
|
||||
@@ -148,7 +148,10 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
"/".join(
|
||||
[".."] * (len(path2.parts) - len(common_parts))
|
||||
+ list(path1.parts[len(common_parts) :])
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -159,10 +162,26 @@ def print_cuda_memory_usage():
|
||||
gc.collect()
|
||||
# Also clear the cache if you want to fully release the memory
|
||||
torch.cuda.empty_cache()
|
||||
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
|
||||
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
|
||||
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
|
||||
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
|
||||
print(
|
||||
"Current GPU Memory Allocated: {:.2f} MB".format(
|
||||
torch.cuda.memory_allocated(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Maximum GPU Memory Allocated: {:.2f} MB".format(
|
||||
torch.cuda.max_memory_allocated(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Current GPU Memory Reserved: {:.2f} MB".format(
|
||||
torch.cuda.memory_reserved(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Maximum GPU Memory Reserved: {:.2f} MB".format(
|
||||
torch.cuda.max_memory_reserved(0) / 1024**2
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def capture_timestamp_utc():
|
||||
@@ -232,7 +251,12 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
|
||||
|
||||
class TimerManager:
|
||||
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
|
||||
def __init__(
|
||||
self,
|
||||
elapsed_time_list: list[float] | None = None,
|
||||
label="Elapsed time",
|
||||
log=True,
|
||||
):
|
||||
self.label = label
|
||||
self.elapsed_time_list = elapsed_time_list
|
||||
self.log = log
|
||||
|
||||
Reference in New Issue
Block a user