[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
761a2dbcb3
commit
8e6d5f504c
@@ -19,7 +19,10 @@ from lerobot.common.datasets.utils import load_image_as_numpy
|
||||
|
||||
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
dataset_len: int,
|
||||
min_num_samples: int = 100,
|
||||
max_num_samples: int = 10_000,
|
||||
power: float = 0.75,
|
||||
) -> int:
|
||||
"""Heuristic to estimate the number of samples based on dataset size.
|
||||
The power controls the sample growth relative to dataset size.
|
||||
@@ -43,14 +46,18 @@ def sample_indices(data_len: int) -> list[int]:
|
||||
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
||||
|
||||
|
||||
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
|
||||
def auto_downsample_height_width(
|
||||
img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
|
||||
):
|
||||
_, height, width = img.shape
|
||||
|
||||
if max(width, height) < max_size_threshold:
|
||||
# no downsampling needed
|
||||
return img
|
||||
|
||||
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
|
||||
downsample_factor = (
|
||||
int(width / target_size) if width > height else int(height / target_size)
|
||||
)
|
||||
return img[:, ::downsample_factor, ::downsample_factor]
|
||||
|
||||
|
||||
@@ -72,7 +79,9 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
def get_feature_stats(
|
||||
array: np.ndarray, axis: tuple, keepdims: bool
|
||||
) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
@@ -82,7 +91,9 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st
|
||||
}
|
||||
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
def compute_episode_stats(
|
||||
episode_data: dict[str, list[str] | np.ndarray], features: dict
|
||||
) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
@@ -96,12 +107,15 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims
|
||||
)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
@@ -116,14 +130,22 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
raise ValueError(
|
||||
"Number of dimensions must be at least 1, and is 0 instead."
|
||||
)
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
raise ValueError(
|
||||
f"Shape of 'count' must be (1), but is {v.shape} instead."
|
||||
)
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
raise ValueError(
|
||||
f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead."
|
||||
)
|
||||
|
||||
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
def aggregate_feature_stats(
|
||||
stats_ft_list: list[dict[str, dict]],
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregates stats for a single feature."""
|
||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||
@@ -152,7 +174,9 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
}
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
def aggregate_stats(
|
||||
stats_list: list[dict[str, dict]],
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
The final stats will have the union of all data keys from each of the stats dicts.
|
||||
|
||||
@@ -58,7 +58,9 @@ def resolve_delta_timestamps(
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
delta_timestamps[key] = [
|
||||
i / ds_meta.fps for i in cfg.observation_delta_indices
|
||||
]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
delta_timestamps = None
|
||||
@@ -79,7 +81,9 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
LeRobotDataset | MultiLeRobotDataset
|
||||
"""
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
ImageTransforms(cfg.dataset.image_transforms)
|
||||
if cfg.dataset.image_transforms.enable
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
@@ -113,6 +117,8 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(
|
||||
stats, dtype=torch.float32
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -38,10 +38,14 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
def image_array_to_pil_image(
|
||||
image_array: np.ndarray, range_check: bool = True
|
||||
) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
raise ValueError(
|
||||
f"The array has {image_array.ndim} dimensions, but 3 is expected for an image."
|
||||
)
|
||||
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
|
||||
@@ -108,7 +108,9 @@ class LeRobotDatasetMetadata:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(
|
||||
self.stats, self.episodes
|
||||
)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
@@ -238,7 +240,9 @@ class LeRobotDatasetMetadata:
|
||||
Given a task in natural language, add it to the dictionary of tasks.
|
||||
"""
|
||||
if task in self.task_to_task_index:
|
||||
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
||||
raise ValueError(
|
||||
f"The task '{task}' already exists and can't be added twice."
|
||||
)
|
||||
|
||||
task_index = self.info["total_tasks"]
|
||||
self.task_to_task_index[task] = task_index
|
||||
@@ -281,7 +285,11 @@ class LeRobotDatasetMetadata:
|
||||
write_episode(episode_dict, self.root)
|
||||
|
||||
self.episodes_stats[episode_index] = episode_stats
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
self.stats = (
|
||||
aggregate_stats([self.stats, episode_stats])
|
||||
if self.stats
|
||||
else episode_stats
|
||||
)
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
|
||||
def update_video_info(self) -> None:
|
||||
@@ -345,13 +353,17 @@ class LeRobotDatasetMetadata:
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
for key in features:
|
||||
if "/" in key:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||
raise ValueError(
|
||||
f"Feature names should not contain '/'. Found '/' in feature '{key}'."
|
||||
)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION, fps, robot_type, features, use_videos
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
@@ -482,7 +494,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.video_backend = (
|
||||
video_backend if video_backend else get_safe_default_codec()
|
||||
)
|
||||
self.delta_indices = None
|
||||
|
||||
# Unused attributes
|
||||
@@ -495,28 +509,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse(
|
||||
"v2.1"
|
||||
):
|
||||
episodes_stats = [
|
||||
self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
|
||||
]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
|
||||
# Load actual data
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
||||
assert all(
|
||||
(self.root / fpath).is_file()
|
||||
for fpath in self.get_episodes_file_paths()
|
||||
)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
self.episode_data_index = get_episode_data_index(
|
||||
self.meta.episodes, self.episodes
|
||||
)
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
check_timestamps_sync(
|
||||
timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
|
||||
)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
@@ -568,7 +593,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||
if not hub_api.file_exists(
|
||||
self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
|
||||
):
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
@@ -576,8 +603,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
if tag_version:
|
||||
with contextlib.suppress(RevisionNotFoundError):
|
||||
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
hub_api.delete_tag(
|
||||
self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
|
||||
)
|
||||
hub_api.create_tag(
|
||||
self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -609,7 +640,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
|
||||
episodes = (
|
||||
self.episodes
|
||||
if self.episodes is not None
|
||||
else list(range(self.meta.total_episodes))
|
||||
)
|
||||
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_files = [
|
||||
@@ -640,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
ft_dict = {col: [] for col in features}
|
||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||
hf_dataset = datasets.Dataset.from_dict(
|
||||
ft_dict, features=features, split="train"
|
||||
)
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
@@ -726,7 +763,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
def _query_videos(
|
||||
self, query_timestamps: dict[str, list[float]], ep_idx: int
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
@@ -735,7 +774,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
||||
frames = decode_video_frames(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
return item
|
||||
@@ -789,7 +830,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
current_ep_idx = (
|
||||
self.meta.total_episodes if episode_index is None else episode_index
|
||||
)
|
||||
ep_buffer = {}
|
||||
# size and task are special cases that are not in self.features
|
||||
ep_buffer["size"] = 0
|
||||
@@ -887,7 +930,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_tasks = list(set(tasks))
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["index"] = np.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Add new tasks to the tasks dictionary
|
||||
@@ -897,12 +942,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.meta.add_task(task)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
episode_buffer["task_index"] = np.array(
|
||||
[self.meta.get_task_index(task) for task in tasks]
|
||||
)
|
||||
|
||||
for key, ft in self.features.items():
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
|
||||
"image",
|
||||
"video",
|
||||
]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
@@ -944,7 +994,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = datasets.Dataset.from_dict(
|
||||
episode_dict, features=self.hf_features, split="train"
|
||||
)
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
|
||||
self.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
@@ -1063,7 +1115,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.video_backend = (
|
||||
video_backend if video_backend is not None else get_safe_default_codec()
|
||||
)
|
||||
return obj
|
||||
|
||||
|
||||
@@ -1088,7 +1142,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
self.tolerances_s = (
|
||||
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
|
||||
@@ -141,12 +141,16 @@ class SharpnessJitter(Transform):
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
|
||||
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
sharpness_factor = (
|
||||
torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
)
|
||||
return {"sharpness_factor": sharpness_factor}
|
||||
|
||||
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
|
||||
sharpness_factor = params["sharpness_factor"]
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
return self._call_kernel(
|
||||
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -135,7 +135,9 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
elif isinstance(value, (int, float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
raise NotImplementedError(
|
||||
f"The value '{value}' of type '{type(value)}' is not supported."
|
||||
)
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
@@ -214,7 +216,10 @@ def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
tasks = {
|
||||
item["task_index"]: item["task"]
|
||||
for item in sorted(tasks, key=lambda x: x["task_index"])
|
||||
}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
@@ -225,13 +230,19 @@ def write_episode(episode: dict, local_dir: Path):
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
return {
|
||||
item["episode_index"]: item
|
||||
for item in sorted(episodes, key=lambda x: x["episode_index"])
|
||||
}
|
||||
|
||||
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
episode_stats = {
|
||||
"episode_index": episode_index,
|
||||
"stats": serialize_dict(episode_stats),
|
||||
}
|
||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
||||
|
||||
|
||||
@@ -275,7 +286,9 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
items_dict[key] = [
|
||||
x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]
|
||||
]
|
||||
return items_dict
|
||||
|
||||
|
||||
@@ -328,7 +341,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
Otherwise, will throw a `CompatibilityError`.
|
||||
"""
|
||||
target_version = (
|
||||
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
||||
packaging.version.parse(version)
|
||||
if not isinstance(version, packaging.version.Version)
|
||||
else version
|
||||
)
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
@@ -349,12 +364,16 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
return f"v{target_version}"
|
||||
|
||||
compatibles = [
|
||||
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
|
||||
v
|
||||
for v in hub_versions
|
||||
if v.major == target_version.major and v.minor <= target_version.minor
|
||||
]
|
||||
if compatibles:
|
||||
return_version = max(compatibles)
|
||||
if return_version < target_version:
|
||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
||||
logging.warning(
|
||||
f"Revision {version} for {repo_id} not found, using version v{return_version}"
|
||||
)
|
||||
return f"v{return_version}"
|
||||
|
||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||
@@ -461,7 +480,9 @@ def create_empty_dataset_info(
|
||||
def get_episode_data_index(
|
||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
||||
episode_lengths = {
|
||||
ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()
|
||||
}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
@@ -511,7 +532,9 @@ def check_timestamps_sync(
|
||||
|
||||
# Mask to ignore differences at the boundaries between episodes
|
||||
mask = np.ones(len(diffs), dtype=bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
|
||||
ignored_diffs = (
|
||||
episode_data_index["to"][:-1] - 1
|
||||
) # indices at the end of each episode
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
@@ -720,14 +743,18 @@ def validate_frame(frame: dict, features: dict):
|
||||
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||
actual_features = set(frame.keys())
|
||||
|
||||
error_message = validate_features_presence(actual_features, expected_features, optional_features)
|
||||
error_message = validate_features_presence(
|
||||
actual_features, expected_features, optional_features
|
||||
)
|
||||
|
||||
if "task" in frame:
|
||||
error_message += validate_feature_string("task", frame["task"])
|
||||
|
||||
common_features = actual_features & (expected_features | optional_features)
|
||||
for name in common_features - {"task"}:
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
error_message += validate_feature_dtype_and_shape(
|
||||
name, features[name], frame[name]
|
||||
)
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
@@ -750,7 +777,9 @@ def validate_features_presence(
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
def validate_feature_dtype_and_shape(
|
||||
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
|
||||
):
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
@@ -760,7 +789,9 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
raise NotImplementedError(
|
||||
f"The feature dtype '{expected_dtype}' is not implemented yet."
|
||||
)
|
||||
|
||||
|
||||
def validate_feature_numpy_array(
|
||||
@@ -782,13 +813,17 @@ def validate_feature_numpy_array(
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
def validate_feature_image_or_video(
|
||||
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
|
||||
):
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
if len(actual_shape) != 3 or (
|
||||
actual_shape != (c, h, w) and actual_shape != (h, w, c)
|
||||
):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
@@ -819,7 +854,9 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
raise ValueError(
|
||||
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
||||
)
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
|
||||
@@ -35,22 +35,30 @@ def fix_dataset(repo_id: str) -> str:
|
||||
|
||||
dataset_info = get_dataset_config_info(repo_id, "default")
|
||||
with SuppressWarnings():
|
||||
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
|
||||
lerobot_metadata = LeRobotDatasetMetadata(
|
||||
repo_id, revision=V20, force_cache_sync=True
|
||||
)
|
||||
|
||||
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
|
||||
meta_features = {
|
||||
key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"
|
||||
}
|
||||
parquet_features = set(dataset_info.features)
|
||||
|
||||
diff_parquet_meta = parquet_features - meta_features
|
||||
diff_meta_parquet = meta_features - parquet_features
|
||||
|
||||
if diff_parquet_meta:
|
||||
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
|
||||
raise ValueError(
|
||||
f"In parquet not in info.json: {parquet_features - meta_features}"
|
||||
)
|
||||
|
||||
if not diff_meta_parquet:
|
||||
return f"{repo_id}: skipped (no diff)"
|
||||
|
||||
if diff_meta_parquet:
|
||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
||||
logging.warning(
|
||||
f"In info.json not in parquet: {meta_features - parquet_features}"
|
||||
)
|
||||
assert diff_meta_parquet == {"language_instruction"}
|
||||
lerobot_metadata.features.pop("language_instruction")
|
||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||
|
||||
@@ -37,8 +37,16 @@ import logging
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
from lerobot.common.datasets.utils import (
|
||||
EPISODES_STATS_PATH,
|
||||
STATS_PATH,
|
||||
load_stats,
|
||||
write_info,
|
||||
)
|
||||
from lerobot.common.datasets.v21.convert_stats import (
|
||||
check_aggregate_stats,
|
||||
convert_stats,
|
||||
)
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
@@ -79,13 +87,21 @@ def convert_dataset(
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
repo_id=dataset.repo_id,
|
||||
filename=STATS_PATH,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
path_in_repo=STATS_PATH,
|
||||
repo_id=dataset.repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
hub_api.create_tag(
|
||||
repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,12 +17,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.compute_stats import (
|
||||
aggregate_stats,
|
||||
get_feature_stats,
|
||||
sample_indices,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
def sample_episode_video_frames(
|
||||
dataset: LeRobotDataset, episode_index: int, ft_key: str
|
||||
) -> np.ndarray:
|
||||
ep_len = dataset.meta.episodes[episode_index]["length"]
|
||||
sampled_indices = sample_indices(ep_len)
|
||||
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
||||
@@ -45,11 +51,14 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_data, axis=axes_to_reduce, keepdims=keepdims
|
||||
)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
@@ -95,5 +104,9 @@ def check_aggregate_stats(
|
||||
if key in reference_stats and stat in reference_stats[key]:
|
||||
err_msg = f"feature='{key}' stats='{stat}'"
|
||||
np.testing.assert_allclose(
|
||||
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
||||
val,
|
||||
reference_stats[key][stat],
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
err_msg=err_msg,
|
||||
)
|
||||
|
||||
@@ -65,7 +65,9 @@ def decode_video_frames(
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
@@ -61,10 +61,16 @@ class AlohaEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
self.features["top"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(480, 640, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
self.features["agent_pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(14,)
|
||||
)
|
||||
self.features["pixels/top"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(480, 640, 3)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
@@ -102,9 +108,13 @@ class PushtEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
self.features["pixels"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(384, 384, 3)
|
||||
)
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
self.features["environment_state"] = PolicyFeature(
|
||||
type=FeatureType.ENV, shape=(16,)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
@@ -143,7 +153,9 @@ class XarmEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
self.features["agent_pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(4,)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
|
||||
@@ -32,7 +32,9 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
||||
def make_env(
|
||||
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
|
||||
) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
@@ -56,7 +58,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
print(
|
||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
|
||||
)
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
@@ -64,7 +68,10 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
|
||||
for _ in range(n_envs)
|
||||
]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
@@ -46,7 +46,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
assert c < h and c < w, (
|
||||
f"expect channel last images, but instead got {img.shape=}"
|
||||
)
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
@@ -79,7 +81,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
for key, ft in env_cfg.features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
if len(ft.shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
|
||||
raise ValueError(
|
||||
f"Number of dimensions of {key} != 3 (shape={ft.shape})"
|
||||
)
|
||||
|
||||
shape = get_channel_first_image_shape(ft.shape)
|
||||
feature = PolicyFeature(type=ft.type, shape=shape)
|
||||
@@ -92,7 +96,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
return policy_features
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
def preprocess_maniskill_observation(
|
||||
observations: dict[str, np.ndarray],
|
||||
) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
|
||||
@@ -250,9 +250,9 @@ class Logger:
|
||||
)
|
||||
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
|
||||
if type(training_state["optimizer"]) is dict:
|
||||
assert set(training_state["optimizer"].keys()) == set(
|
||||
optimizer.keys()
|
||||
), "Optimizer dictionaries do not have the same keys during resume!"
|
||||
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
|
||||
"Optimizer dictionaries do not have the same keys during resume!"
|
||||
)
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizer[k].load_state_dict(v)
|
||||
else:
|
||||
|
||||
@@ -34,7 +34,13 @@ def make_optimizer_and_scheduler(
|
||||
Returns:
|
||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
params = (
|
||||
policy.get_optim_params()
|
||||
if cfg.use_policy_training_preset
|
||||
else policy.parameters()
|
||||
)
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
lr_scheduler = (
|
||||
cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
)
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
@@ -102,7 +102,9 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
def load_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer, save_dir: Path
|
||||
) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
|
||||
@@ -36,7 +36,9 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
|
||||
def build(
|
||||
self, optimizer: Optimizer, num_training_steps: int
|
||||
) -> LRScheduler | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -49,7 +51,11 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
|
||||
kwargs = {
|
||||
**asdict(self),
|
||||
"num_training_steps": num_training_steps,
|
||||
"optimizer": optimizer,
|
||||
}
|
||||
return get_scheduler(**kwargs)
|
||||
|
||||
|
||||
@@ -71,7 +77,14 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
progress = float(adjusted_step - self.num_warmup_steps) / float(
|
||||
max(1, num_training_steps - self.num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
|
||||
return max(
|
||||
0.0,
|
||||
0.5
|
||||
* (
|
||||
1.0
|
||||
+ math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
|
||||
),
|
||||
)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
@@ -98,7 +111,9 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
cosine_decay = 0.5 * (
|
||||
1 + math.cos(math.pi * step / self.num_decay_steps)
|
||||
)
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
@@ -117,6 +132,8 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
|
||||
|
||||
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
|
||||
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
|
||||
state_dict = deserialize_json_into_object(
|
||||
save_dir / SCHEDULER_STATE, scheduler.state_dict()
|
||||
)
|
||||
scheduler.load_state_dict(state_dict)
|
||||
return scheduler
|
||||
|
||||
@@ -171,7 +171,9 @@ class ACTConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
|
||||
@@ -63,7 +63,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -120,8 +122,12 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [
|
||||
batch[key] for key in self.config.image_features
|
||||
]
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -148,8 +154,12 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [
|
||||
batch[key] for key in self.config.image_features
|
||||
]
|
||||
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
@@ -406,14 +416,18 @@ class ACT(nn.Module):
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.config.image_features:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
|
||||
config.dim_model // 2
|
||||
)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
|
||||
self.action_head = nn.Linear(
|
||||
config.dim_model, self.config.action_feature.shape[0]
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -461,14 +475,20 @@ class ACT(nn.Module):
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(
|
||||
batch["observation.state"]
|
||||
)
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(
|
||||
batch["action"]
|
||||
) # (B, S, D)
|
||||
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
vae_encoder_input = [
|
||||
cls_embed,
|
||||
robot_state_embed,
|
||||
action_embed,
|
||||
] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||
@@ -517,7 +537,9 @@ class ACT(nn.Module):
|
||||
)
|
||||
# Robot state token.
|
||||
if self.config.robot_state_feature:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_robot_state_input_proj(batch["observation.state"])
|
||||
)
|
||||
# Environment state token.
|
||||
if self.config.env_state_feature:
|
||||
encoder_in_tokens.append(
|
||||
@@ -534,7 +556,9 @@ class ACT(nn.Module):
|
||||
# For a list of images, the H and W may vary but H*W is constant.
|
||||
for img in batch["observation.images"]:
|
||||
cam_features = self.backbone(img)["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
|
||||
dtype=cam_features.dtype
|
||||
)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
||||
|
||||
# Rearrange features to (sequence, batch, dim).
|
||||
|
||||
@@ -205,11 +205,16 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
if (
|
||||
self.crop_shape[0] > image_ft.shape[1]
|
||||
or self.crop_shape[1] > image_ft.shape[2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
|
||||
@@ -70,7 +70,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -97,7 +99,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues["observation.environment_state"] = deque(
|
||||
maxlen=self.config.n_obs_steps
|
||||
)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -123,7 +127,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
@@ -151,7 +157,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
@@ -515,11 +523,15 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
)
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -719,7 +731,9 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
DiffusionConv1dBlock(
|
||||
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
|
||||
),
|
||||
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||
)
|
||||
|
||||
|
||||
@@ -104,7 +104,9 @@ def make_policy(
|
||||
PreTrainedPolicy: _description_
|
||||
"""
|
||||
if bool(ds_meta) == bool(env_cfg):
|
||||
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
||||
raise ValueError(
|
||||
"Either one of a dataset metadata or a sim env must be provided."
|
||||
)
|
||||
|
||||
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
|
||||
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
|
||||
@@ -134,8 +136,12 @@ def make_policy(
|
||||
)
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
cfg.output_features = {
|
||||
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
|
||||
}
|
||||
cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in cfg.output_features
|
||||
}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
|
||||
@@ -82,25 +82,43 @@ def create_stats_buffers(
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
buffer["mean"].data = (
|
||||
stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
buffer["std"].data = (
|
||||
stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
buffer["min"].data = (
|
||||
stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
buffer["max"].data = (
|
||||
stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
raise ValueError(
|
||||
f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead."
|
||||
)
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
@@ -44,7 +44,9 @@ def main():
|
||||
else:
|
||||
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
|
||||
|
||||
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_torch_dir = (
|
||||
Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
)
|
||||
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
|
||||
save_dir = Path(f"../openpi/data/{model_name}/save")
|
||||
|
||||
@@ -70,7 +72,9 @@ def main():
|
||||
# Create LeRobot batch from Jax
|
||||
batch = {}
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch[f"observation.images.{cam_key}"] = (
|
||||
torch.from_numpy(uint_chw_array) / 255.0
|
||||
)
|
||||
batch["observation.state"] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
@@ -54,7 +54,9 @@ def get_paligemma_config(precision: str):
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"vision_use_head": False,
|
||||
}
|
||||
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
|
||||
final_config = PaliGemmaConfig(
|
||||
text_config=text_config, vision_config=vision_config, **config
|
||||
)
|
||||
return final_config
|
||||
|
||||
|
||||
|
||||
@@ -61,7 +61,11 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
||||
)
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
PRECISIONS = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
}
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
@@ -318,7 +322,9 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
||||
def convert_pi0_checkpoint(
|
||||
checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str
|
||||
):
|
||||
# Break down orbax ckpts - they are in OCDBT
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||
# process projection params
|
||||
@@ -378,7 +384,9 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
|
||||
# gemma_config=gemma_config, paligemma_config=paligemma_config)
|
||||
pi0_model = PI0Policy(pi0_config)
|
||||
|
||||
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
|
||||
paligemma_params = update_keys_with_prefix(
|
||||
paligemma_params, "model.paligemma_with_expert."
|
||||
)
|
||||
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
|
||||
projection_params = update_keys_with_prefix(projection_params, "model.")
|
||||
|
||||
|
||||
@@ -48,18 +48,32 @@ def flex_attention_forward(
|
||||
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
key_states.shape[1],
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
key_states.shape[1],
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
value_states.shape[1],
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
value_states.shape[1],
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
@@ -69,7 +69,11 @@ from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
time: torch.tensor,
|
||||
dimension: int,
|
||||
min_period: float,
|
||||
max_period: float,
|
||||
device="cpu",
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
@@ -189,7 +193,9 @@ def aloha_gripper_to_angular(value):
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
|
||||
2 * horn_radius * linear_position
|
||||
)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
@@ -240,7 +246,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -248,7 +256,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/paligemma-3b-pt-224"
|
||||
)
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
@@ -261,7 +271,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def select_action(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
@@ -300,7 +312,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
def forward(
|
||||
self, batch: dict[str, Tensor], noise=None, time=None
|
||||
) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
@@ -316,7 +330,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
losses = self.model.forward(
|
||||
images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
|
||||
)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
@@ -343,7 +359,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
img_masks = []
|
||||
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
missing_img_keys = [
|
||||
key for key in self.config.image_features if key not in batch
|
||||
]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
@@ -355,7 +373,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
img = resize_with_pad(
|
||||
img, *self.config.resize_imgs_with_padding, pad_value=0
|
||||
)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
@@ -394,7 +414,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(
|
||||
device=device, dtype=torch.bool
|
||||
)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
@@ -413,7 +435,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(
|
||||
actions[:, :, motor_idx]
|
||||
)
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
@@ -422,7 +446,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
|
||||
actions[:, :, motor_idx]
|
||||
)
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
@@ -472,15 +498,25 @@ class PI0FlowMatching(nn.Module):
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
attention_implementation=self.config.attention_implementation,
|
||||
)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_with_export_config
|
||||
)
|
||||
|
||||
# Projections are float32
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
|
||||
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
|
||||
self.action_in_proj = nn.Linear(
|
||||
self.config.max_action_dim, self.config.proj_width
|
||||
)
|
||||
self.action_out_proj = nn.Linear(
|
||||
self.config.proj_width, self.config.max_action_dim
|
||||
)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
|
||||
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
|
||||
self.action_time_mlp_in = nn.Linear(
|
||||
self.config.proj_width * 2, self.config.proj_width
|
||||
)
|
||||
self.action_time_mlp_out = nn.Linear(
|
||||
self.config.proj_width, self.config.proj_width
|
||||
)
|
||||
|
||||
self.set_requires_grad()
|
||||
|
||||
@@ -524,7 +560,9 @@ class PI0FlowMatching(nn.Module):
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
img_emb = img_emb * torch.tensor(
|
||||
img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
|
||||
)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
@@ -577,7 +615,11 @@ class PI0FlowMatching(nn.Module):
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
||||
timestep,
|
||||
self.config.proj_width,
|
||||
min_period=4e-3,
|
||||
max_period=4.0,
|
||||
device=device,
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
@@ -595,7 +637,9 @@ class PI0FlowMatching(nn.Module):
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
action_time_mask = torch.ones(
|
||||
bsize, action_time_dim, dtype=torch.bool, device=device
|
||||
)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
@@ -609,7 +653,15 @@ class PI0FlowMatching(nn.Module):
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
actions,
|
||||
noise=None,
|
||||
time=None,
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
@@ -625,7 +677,9 @@ class PI0FlowMatching(nn.Module):
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
||||
state, x_t, time
|
||||
)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
@@ -649,13 +703,19 @@ class PI0FlowMatching(nn.Module):
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
def sample_actions(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
||||
actions_shape = (
|
||||
bsize,
|
||||
self.config.n_action_steps,
|
||||
self.config.max_action_dim,
|
||||
)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -703,12 +763,16 @@ class PI0FlowMatching(nn.Module):
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
||||
state, x_t, timestep
|
||||
)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
|
||||
batch_size, suffix_len, prefix_len
|
||||
)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
|
||||
@@ -39,9 +39,13 @@ def apply_rope(x, positions, max_wavelength=10_000):
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
|
||||
d_half, dtype=torch.float32, device=device
|
||||
)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
|
||||
torch.float32
|
||||
)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
@@ -174,7 +178,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
def __init__(self, config: PaliGemmaWithExpertConfig):
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(
|
||||
config=config.paligemma_config
|
||||
)
|
||||
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
|
||||
# Remove unused embed_tokens
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
@@ -291,14 +297,22 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
key_states = torch.cat(
|
||||
[past_key_values[layer_idx]["key_states"], key_states], dim=1
|
||||
)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
[past_key_values[layer_idx]["value_states"], value_states],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
|
||||
@@ -358,15 +372,29 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
self,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
self,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
):
|
||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||
num_key_value_heads = (
|
||||
self.config.paligemma_config.text_config.num_key_value_heads
|
||||
)
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
||||
@@ -375,17 +403,31 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
@@ -400,7 +442,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
att_weights *= head_dim**-0.5
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
masked_att_weights = torch.where(
|
||||
attention_mask[:, None, :, :], att_weights, big_neg
|
||||
)
|
||||
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
@@ -412,6 +456,8 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
att_output = att_output.reshape(
|
||||
batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
|
||||
)
|
||||
|
||||
return att_output
|
||||
|
||||
@@ -71,7 +71,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
save_model_as_safetensor(
|
||||
model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -110,7 +112,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||
policy = cls._load_as_safetensor(
|
||||
instance, model_file, config.device, strict
|
||||
)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
@@ -124,7 +128,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||
policy = cls._load_as_safetensor(
|
||||
instance, model_file, config.device, strict
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
@@ -135,8 +141,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
return policy
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
|
||||
def _load_as_safetensor(
|
||||
cls, model: T, model_file: str, map_location: str, strict: bool
|
||||
) -> T:
|
||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse(
|
||||
"0.4.3"
|
||||
):
|
||||
load_model_as_safetensor(model, model_file, strict=strict)
|
||||
if map_location != "cpu":
|
||||
logging.warning(
|
||||
@@ -147,7 +157,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
)
|
||||
model.to(map_location)
|
||||
else:
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
safetensors.torch.load_model(
|
||||
model, model_file, strict=strict, device=map_location
|
||||
)
|
||||
return model
|
||||
|
||||
# def generate_model_card(self, *args, **kwargs) -> ModelCard:
|
||||
|
||||
@@ -639,9 +639,9 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(
|
||||
log_std
|
||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
assert not torch.isnan(log_std).any(), (
|
||||
"[ERROR] log_std became NaN after std_layer!"
|
||||
)
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
|
||||
@@ -187,7 +187,9 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
|
||||
)
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
raise ValueError(
|
||||
"`n_action_steps` must be less than or equal to `horizon`."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(lr=self.optimizer_lr)
|
||||
@@ -207,7 +209,9 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
if image_ft.shape[-2] != image_ft.shape[-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {image_ft.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
|
||||
@@ -39,7 +39,11 @@ from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
|
||||
class TDMPCPolicy(PreTrainedPolicy):
|
||||
@@ -63,7 +67,11 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
@@ -75,7 +83,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -117,7 +127,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
@@ -201,7 +213,10 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
|
||||
self.config.horizon,
|
||||
batch_size,
|
||||
self.config.action_feature.shape[0],
|
||||
device=device,
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
@@ -339,7 +354,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
@@ -371,7 +388,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self.config.image_features else "observation.environment_state"
|
||||
"observation.image"
|
||||
if self.config.image_features
|
||||
else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
@@ -569,7 +588,9 @@ class TDMPCTOLD(nn.Module):
|
||||
self.config = config
|
||||
self._encoder = TDMPCObservationEncoder(config)
|
||||
self._dynamics = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -580,7 +601,9 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self._reward = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -600,7 +623,10 @@ class TDMPCTOLD(nn.Module):
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(
|
||||
config.latent_dim + config.action_feature.shape[0],
|
||||
config.mlp_dim,
|
||||
),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -786,7 +812,9 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
|
||||
if config.robot_state_feature:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(
|
||||
config.robot_state_feature.shape[0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -795,7 +823,9 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
|
||||
if config.env_state_feature:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(
|
||||
config.env_state_feature.shape[0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -813,7 +843,8 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
if self.config.image_features:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(
|
||||
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
|
||||
self.image_enc_layers,
|
||||
obs_dict[next(iter(self.config.image_features))],
|
||||
)
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
|
||||
@@ -172,7 +172,10 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
if (
|
||||
self.crop_shape[0] > image_ft.shape[1]
|
||||
or self.crop_shape[1] > image_ft.shape[2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
@@ -193,7 +196,12 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
|
||||
return list(
|
||||
range(
|
||||
1 - self.n_obs_steps,
|
||||
self.n_action_pred_token + self.action_chunk_size - 1,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -29,7 +29,11 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
|
||||
@@ -60,7 +64,9 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -91,11 +97,17 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
if self.config.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
+ list(
|
||||
self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
|
||||
)
|
||||
+ list(
|
||||
self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
|
||||
)
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
decay_params = decay_params + list(
|
||||
self.vqbet.action_head.map_to_cbet_preds_bin.parameters()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -133,8 +145,12 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -165,8 +181,12 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
@@ -334,7 +354,8 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
|
||||
config.robot_state_feature.shape[0],
|
||||
hidden_channels=[self.config.gpt_input_dim],
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
@@ -406,9 +427,9 @@ class VQBeTModel(nn.Module):
|
||||
features = self.policy(input_tokens)
|
||||
# len(self.config.input_features) is the number of different observation modes.
|
||||
# this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
|
||||
self.config.input_features
|
||||
)
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (
|
||||
len(self.config.input_features) + 1
|
||||
) + len(self.config.input_features)
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
|
||||
@@ -771,11 +792,15 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
)
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -871,7 +896,8 @@ class VqVae(nn.Module):
|
||||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
|
||||
in_channels=self.config.action_feature.shape[0]
|
||||
* self.config.action_chunk_size,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
@@ -899,9 +925,13 @@ class VqVae(nn.Module):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
return einops.rearrange(
|
||||
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
|
||||
)
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
return einops.rearrange(
|
||||
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
|
||||
)
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
|
||||
@@ -290,10 +290,10 @@ class GPT(nn.Module):
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert (
|
||||
len(inter_params) == 0
|
||||
), "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
assert len(inter_params) == 0, (
|
||||
"parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
@@ -664,14 +664,14 @@ class VectorQuantize(nn.Module):
|
||||
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
||||
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
||||
|
||||
assert not (
|
||||
ema_update and learnable_codebook
|
||||
), "learnable codebook not compatible with EMA update"
|
||||
assert not (ema_update and learnable_codebook), (
|
||||
"learnable codebook not compatible with EMA update"
|
||||
)
|
||||
|
||||
assert 0 <= sync_update_v <= 1.0
|
||||
assert not (
|
||||
sync_update_v > 0.0 and not learnable_codebook
|
||||
), "learnable codebook must be turned on"
|
||||
assert not (sync_update_v > 0.0 and not learnable_codebook), (
|
||||
"learnable codebook must be turned on"
|
||||
)
|
||||
|
||||
self.sync_update_v = sync_update_v
|
||||
|
||||
|
||||
@@ -57,7 +57,9 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
raise ValueError(
|
||||
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
|
||||
)
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("intelrealsense")
|
||||
@@ -102,8 +104,12 @@ class IntelRealSenseCameraConfig(CameraConfig):
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
at_least_one_is_not_none = (
|
||||
self.fps is not None or self.width is not None or self.height is not None
|
||||
)
|
||||
at_least_one_is_none = (
|
||||
self.fps is None or self.width is None or self.height is None
|
||||
)
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
@@ -111,4 +117,6 @@ class IntelRealSenseCameraConfig(CameraConfig):
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
raise ValueError(
|
||||
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
|
||||
)
|
||||
|
||||
@@ -303,7 +303,11 @@ class IntelRealSenseCamera:
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
# TODO(rcadene): can we set rgb8 directly?
|
||||
config.enable_stream(
|
||||
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
|
||||
rs.stream.color,
|
||||
self.capture_width,
|
||||
self.capture_height,
|
||||
rs.format.rgb8,
|
||||
self.fps,
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.color)
|
||||
@@ -311,7 +315,11 @@ class IntelRealSenseCamera:
|
||||
if self.use_depth:
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
config.enable_stream(
|
||||
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
|
||||
rs.stream.depth,
|
||||
self.capture_width,
|
||||
self.capture_height,
|
||||
rs.format.z16,
|
||||
self.fps,
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.depth)
|
||||
|
||||
@@ -144,7 +144,9 @@ def save_images_from_cameras(
|
||||
print("Connecting cameras")
|
||||
cameras = []
|
||||
for cam_idx in camera_ids:
|
||||
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
|
||||
config = OpenCVCameraConfig(
|
||||
camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock
|
||||
)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
@@ -250,7 +252,9 @@ class OpenCVCamera:
|
||||
# Retrieve the camera index from a potentially symlinked path
|
||||
self.camera_index = get_camera_index_from_unix_port(self.port)
|
||||
else:
|
||||
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
||||
raise ValueError(
|
||||
f"Please check the provided camera_index: {self.camera_index}"
|
||||
)
|
||||
|
||||
# Store the raw (capture) resolution from the config.
|
||||
self.capture_width = config.width
|
||||
@@ -314,7 +318,11 @@ class OpenCVCamera:
|
||||
else cv2.CAP_ANY
|
||||
)
|
||||
|
||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||
camera_idx = (
|
||||
f"/dev/video{self.camera_index}"
|
||||
if platform.system() == "Linux"
|
||||
else self.camera_index
|
||||
)
|
||||
# First create a temporary camera trying to access `camera_index`,
|
||||
# and verify it is a valid camera by calling `isOpened`.
|
||||
tmp_camera = cv2.VideoCapture(camera_idx, backend)
|
||||
|
||||
@@ -41,7 +41,9 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
|
||||
cameras[key] = OpenCVCamera(cfg)
|
||||
|
||||
elif cfg.type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
||||
IntelRealSenseCamera,
|
||||
)
|
||||
|
||||
cameras[key] = IntelRealSenseCamera(cfg)
|
||||
else:
|
||||
@@ -58,7 +60,9 @@ def make_camera(camera_type, **kwargs) -> Camera:
|
||||
return OpenCVCamera(config)
|
||||
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
||||
IntelRealSenseCamera,
|
||||
)
|
||||
|
||||
config = IntelRealSenseCameraConfig(**kwargs)
|
||||
return IntelRealSenseCamera(config)
|
||||
|
||||
@@ -93,7 +93,9 @@ class RecordControlConfig(ControlConfig):
|
||||
policy_path = parser.get_path_arg("control.policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("control.policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
|
||||
|
||||
@@ -282,7 +282,10 @@ def control_loop(
|
||||
|
||||
if policy is not None:
|
||||
pred_action = predict_action(
|
||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
observation,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
)
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
|
||||
@@ -23,7 +23,10 @@ import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
|
||||
@@ -23,7 +23,10 @@ import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 0
|
||||
|
||||
@@ -30,7 +30,9 @@ class MotorsBus(Protocol):
|
||||
def write(self): ...
|
||||
|
||||
|
||||
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
|
||||
def make_motors_buses_from_configs(
|
||||
motors_bus_configs: dict[str, MotorsBusConfig],
|
||||
) -> list[MotorsBus]:
|
||||
motors_buses = {}
|
||||
|
||||
for key, cfg in motors_bus_configs.items():
|
||||
|
||||
@@ -69,9 +69,13 @@ class ManipulatorRobotConfig(RobotConfig):
|
||||
if not cam.mock:
|
||||
cam.mock = True
|
||||
|
||||
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
|
||||
if self.max_relative_target is not None and isinstance(
|
||||
self.max_relative_target, Sequence
|
||||
):
|
||||
for name in self.follower_arms:
|
||||
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
|
||||
if len(self.follower_arms[name].motors) != len(
|
||||
self.max_relative_target
|
||||
):
|
||||
raise ValueError(
|
||||
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
|
||||
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
|
||||
|
||||
@@ -42,7 +42,9 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
|
||||
local_dict = {}
|
||||
for name, cam in cameras.items():
|
||||
frame = cam.async_read()
|
||||
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
||||
ret, buffer = cv2.imencode(
|
||||
".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]
|
||||
)
|
||||
if ret:
|
||||
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
|
||||
else:
|
||||
@@ -61,7 +63,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||
calib_dir.mkdir(parents=True, exist_ok=True)
|
||||
calib_file = calib_dir / "main_follower.json"
|
||||
try:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
except ImportError:
|
||||
print("[WARNING] Calibration function not available. Skipping calibration.")
|
||||
return
|
||||
@@ -72,7 +76,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||
else:
|
||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||
calibration = run_arm_manual_calibration(
|
||||
motors_bus, "lekiwi", "follower_arm", "follower"
|
||||
)
|
||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||
with open(calib_file, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
@@ -116,7 +122,14 @@ def run_lekiwi(robot_config):
|
||||
robot = LeKiwi(motors_bus)
|
||||
|
||||
# Define the expected arm motor IDs.
|
||||
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
|
||||
arm_motor_ids = [
|
||||
"shoulder_pan",
|
||||
"shoulder_lift",
|
||||
"elbow_flex",
|
||||
"wrist_flex",
|
||||
"wrist_roll",
|
||||
"gripper",
|
||||
]
|
||||
|
||||
# Disable torque for each arm motor.
|
||||
for motor in arm_motor_ids:
|
||||
@@ -130,7 +143,9 @@ def run_lekiwi(robot_config):
|
||||
images_lock = threading.Lock()
|
||||
stop_event = threading.Event()
|
||||
cam_thread = threading.Thread(
|
||||
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
|
||||
target=run_camera_capture,
|
||||
args=(cameras, images_lock, latest_images_dict, stop_event),
|
||||
daemon=True,
|
||||
)
|
||||
cam_thread.start()
|
||||
|
||||
@@ -159,7 +174,9 @@ def run_lekiwi(robot_config):
|
||||
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
|
||||
)
|
||||
else:
|
||||
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
|
||||
for motor, pos in zip(
|
||||
arm_motor_ids, arm_positions, strict=False
|
||||
):
|
||||
motors_bus.write("Goal_Position", pos, motor)
|
||||
# Process wheel (base) commands.
|
||||
if "raw_velocity" in data:
|
||||
@@ -190,7 +207,9 @@ def run_lekiwi(robot_config):
|
||||
try:
|
||||
pos = motors_bus.read("Present_Position", motor)
|
||||
# Convert the position to a float (or use as is if already numeric).
|
||||
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
|
||||
follower_arm_state.append(
|
||||
float(pos) if not isinstance(pos, (int, float)) else pos
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Reading motor {motor} failed: {e}")
|
||||
|
||||
|
||||
@@ -28,7 +28,10 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||
from lerobot.common.robot_devices.motors.utils import (
|
||||
MotorsBus,
|
||||
make_motors_buses_from_configs,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
|
||||
@@ -25,9 +25,14 @@ import zmq
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||
from lerobot.common.robot_devices.motors.utils import (
|
||||
MotorsBus,
|
||||
make_motors_buses_from_configs,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
|
||||
|
||||
@@ -266,7 +271,9 @@ class MobileManipulator:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
calibration = run_arm_manual_calibration(
|
||||
arm, self.robot_type, name, arm_type
|
||||
)
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
@@ -296,7 +303,9 @@ class MobileManipulator:
|
||||
bus.write("Torque_Enable", 0, motor_id)
|
||||
|
||||
# Then filter out wheels
|
||||
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
|
||||
arm_only_dict = {
|
||||
k: v for k, v in bus.motors.items() if not k.startswith("wheel_")
|
||||
}
|
||||
if not arm_only_dict:
|
||||
continue
|
||||
|
||||
@@ -324,7 +333,11 @@ class MobileManipulator:
|
||||
socks = dict(poller.poll(15))
|
||||
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
||||
# No new data arrived → reuse ALL old data
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
return (
|
||||
self.last_frames,
|
||||
self.last_present_speed,
|
||||
self.last_remote_arm_state,
|
||||
)
|
||||
|
||||
# Drain all messages, keep only the last
|
||||
last_msg = None
|
||||
@@ -337,7 +350,11 @@ class MobileManipulator:
|
||||
|
||||
if not last_msg:
|
||||
# No new message → also reuse old
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
return (
|
||||
self.last_frames,
|
||||
self.last_present_speed,
|
||||
self.last_remote_arm_state,
|
||||
)
|
||||
|
||||
# Decode only the final message
|
||||
try:
|
||||
@@ -360,7 +377,9 @@ class MobileManipulator:
|
||||
if new_arm_state is not None and frames is not None:
|
||||
self.last_frames = frames
|
||||
|
||||
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
|
||||
remote_arm_state_tensor = torch.tensor(
|
||||
new_arm_state, dtype=torch.float32
|
||||
)
|
||||
self.last_remote_arm_state = remote_arm_state_tensor
|
||||
|
||||
present_speed = new_speed
|
||||
@@ -375,14 +394,21 @@ class MobileManipulator:
|
||||
except Exception as e:
|
||||
print(f"[DEBUG] Error decoding video message: {e}")
|
||||
# If decode fails, fall back to old data
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
return (
|
||||
self.last_frames,
|
||||
self.last_present_speed,
|
||||
self.last_remote_arm_state,
|
||||
)
|
||||
|
||||
return frames, present_speed, remote_arm_state_tensor
|
||||
|
||||
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
|
||||
state_tensor = torch.zeros(3, dtype=torch.int32)
|
||||
if present_speed:
|
||||
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
|
||||
decoded = {
|
||||
key: MobileManipulator.raw_to_degps(value)
|
||||
for key, value in present_speed.items()
|
||||
}
|
||||
if "1" in decoded:
|
||||
state_tensor[0] = decoded["1"]
|
||||
if "2" in decoded:
|
||||
@@ -395,7 +421,9 @@ class MobileManipulator:
|
||||
self, record_data: bool = False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"MobileManipulator is not connected. Run `connect()` first."
|
||||
)
|
||||
|
||||
speed_setting = self.speed_levels[self.speed_index]
|
||||
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
|
||||
@@ -461,9 +489,15 @@ class MobileManipulator:
|
||||
|
||||
body_state = self.wheel_raw_to_body(present_speed)
|
||||
|
||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||
body_state_mm = (
|
||||
body_state[0] * 1000.0,
|
||||
body_state[1] * 1000.0,
|
||||
body_state[2],
|
||||
) # Convert x,y to mm/s
|
||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||
combined_state_tensor = torch.cat(
|
||||
(remote_arm_state_tensor, wheel_state_tensor), dim=0
|
||||
)
|
||||
|
||||
obs_dict = {"observation.state": combined_state_tensor}
|
||||
|
||||
@@ -620,7 +654,11 @@ class MobileManipulator:
|
||||
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
||||
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
||||
|
||||
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
||||
return {
|
||||
"left_wheel": wheel_raw[0],
|
||||
"back_wheel": wheel_raw[1],
|
||||
"right_wheel": wheel_raw[2],
|
||||
}
|
||||
|
||||
def wheel_raw_to_body(
|
||||
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||
|
||||
@@ -72,7 +72,9 @@ def make_robot_from_config(config: RobotConfig):
|
||||
|
||||
return ManipulatorRobot(config)
|
||||
elif isinstance(config, LeKiwiRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
|
||||
from lerobot.common.robot_devices.robots.mobile_manipulator import (
|
||||
MobileManipulator,
|
||||
)
|
||||
|
||||
return MobileManipulator(config)
|
||||
else:
|
||||
|
||||
@@ -69,7 +69,9 @@ class HubMixin:
|
||||
if push_to_hub:
|
||||
if repo_id is None:
|
||||
repo_id = save_directory.name # Defaults to `save_directory` name
|
||||
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
||||
return self.push_to_hub(
|
||||
repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs
|
||||
)
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
@@ -175,7 +177,9 @@ class HubMixin:
|
||||
The url of the commit of your object in the given repository.
|
||||
"""
|
||||
api = HfApi(token=token)
|
||||
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
|
||||
repo_id = api.create_repo(
|
||||
repo_id=repo_id, private=private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
if commit_message is None:
|
||||
if "Policy" in self.__class__.__name__:
|
||||
|
||||
@@ -20,7 +20,16 @@ from typing import TypeVar
|
||||
|
||||
import imageio
|
||||
|
||||
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
||||
JsonLike = (
|
||||
str
|
||||
| int
|
||||
| float
|
||||
| bool
|
||||
| None
|
||||
| list["JsonLike"]
|
||||
| dict[str, "JsonLike"]
|
||||
| tuple["JsonLike", ...]
|
||||
)
|
||||
T = TypeVar("T", bound=JsonLike)
|
||||
|
||||
|
||||
@@ -76,7 +85,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
|
||||
# Check length
|
||||
if len(target) != len(source):
|
||||
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
|
||||
raise ValueError(
|
||||
f"List length mismatch: expected {len(target)}, got {len(source)}"
|
||||
)
|
||||
|
||||
# Recursively update each element.
|
||||
for i in range(len(target)):
|
||||
@@ -88,10 +99,14 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
# which we'll convert back to a tuple.
|
||||
elif isinstance(target, tuple):
|
||||
if not isinstance(source, list):
|
||||
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
|
||||
raise TypeError(
|
||||
f"Type mismatch: expected list (for tuple), got {type(source)}"
|
||||
)
|
||||
|
||||
if len(target) != len(source):
|
||||
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
|
||||
raise ValueError(
|
||||
f"Tuple length mismatch: expected {len(target)}, got {len(source)}"
|
||||
)
|
||||
|
||||
# Convert each element, forming a new tuple.
|
||||
converted_items = []
|
||||
@@ -105,7 +120,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
else:
|
||||
# Check the exact type. If these must match 1:1, do:
|
||||
if type(target) is not type(source):
|
||||
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
|
||||
raise TypeError(
|
||||
f"Type mismatch: expected {type(target)}, got {type(source)}"
|
||||
)
|
||||
return source
|
||||
|
||||
# Perform the in-place/recursive deserialization
|
||||
|
||||
@@ -107,13 +107,17 @@ class MetricsTracker:
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
||||
def __getattr__(
|
||||
self, name: str
|
||||
) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
||||
if name in self.__dict__:
|
||||
return self.__dict__[name]
|
||||
elif name in self.metrics:
|
||||
return self.metrics[name]
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in self.__dict__:
|
||||
@@ -121,7 +125,9 @@ class MetricsTracker:
|
||||
elif name in self.metrics:
|
||||
self.metrics[name].update(value)
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -42,7 +42,11 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non
|
||||
"""
|
||||
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
|
||||
"""
|
||||
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
|
||||
py_state = (
|
||||
rng_state_dict["py_rng_version"].item(),
|
||||
tuple(rng_state_dict["py_rng_state"].tolist()),
|
||||
None,
|
||||
)
|
||||
random.setstate(py_state)
|
||||
|
||||
|
||||
@@ -119,7 +123,9 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
|
||||
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
|
||||
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
|
||||
torch_rng_state_dict = {
|
||||
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
|
||||
}
|
||||
|
||||
deserialize_python_rng_state(py_rng_state_dict)
|
||||
deserialize_numpy_rng_state(np_rng_state_dict)
|
||||
|
||||
@@ -48,7 +48,9 @@ def auto_select_torch_device() -> torch.device:
|
||||
logging.info("Metal backend detected, using cuda.")
|
||||
return torch.device("mps")
|
||||
else:
|
||||
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
|
||||
logging.warning(
|
||||
"No accelerated backend detected. Using default cpu, this will be slow."
|
||||
)
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@@ -96,7 +98,9 @@ def is_torch_device_available(try_device: str) -> bool:
|
||||
elif try_device == "cpu":
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
||||
raise ValueError(
|
||||
f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu."
|
||||
)
|
||||
|
||||
|
||||
def is_amp_available(device: str):
|
||||
@@ -219,7 +223,9 @@ def say(text, blocking=False):
|
||||
if blocking:
|
||||
subprocess.run(cmd, check=True)
|
||||
else:
|
||||
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
|
||||
subprocess.Popen(
|
||||
cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
|
||||
)
|
||||
|
||||
|
||||
def log_say(text, play_sounds, blocking=False):
|
||||
|
||||
@@ -26,7 +26,9 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
def cfg_to_group(
|
||||
cfg: TrainPipelineConfig, return_list: bool = False
|
||||
) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
@@ -93,7 +95,9 @@ class WandBLogger:
|
||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
logging.info(
|
||||
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
|
||||
)
|
||||
self._wandb = wandb
|
||||
|
||||
def log_policy(self, checkpoint_dir: Path):
|
||||
@@ -105,7 +109,9 @@ class WandBLogger:
|
||||
artifact_name = f"{self._group}-{step_id}"
|
||||
artifact_name = get_safe_wandb_artifact_name(artifact_name)
|
||||
artifact = self._wandb.Artifact(artifact_name, type="model")
|
||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
artifact.add_file(
|
||||
checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE
|
||||
)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
|
||||
Reference in New Issue
Block a user