[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
2abbd60a0d
commit
0ea27704f6
@@ -46,18 +46,14 @@ def sample_indices(data_len: int) -> list[int]:
|
||||
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
||||
|
||||
|
||||
def auto_downsample_height_width(
|
||||
img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
|
||||
):
|
||||
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
|
||||
_, height, width = img.shape
|
||||
|
||||
if max(width, height) < max_size_threshold:
|
||||
# no downsampling needed
|
||||
return img
|
||||
|
||||
downsample_factor = (
|
||||
int(width / target_size) if width > height else int(height / target_size)
|
||||
)
|
||||
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
|
||||
return img[:, ::downsample_factor, ::downsample_factor]
|
||||
|
||||
|
||||
@@ -79,9 +75,7 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(
|
||||
array: np.ndarray, axis: tuple, keepdims: bool
|
||||
) -> dict[str, np.ndarray]:
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
@@ -91,9 +85,7 @@ def get_feature_stats(
|
||||
}
|
||||
|
||||
|
||||
def compute_episode_stats(
|
||||
episode_data: dict[str, list[str] | np.ndarray], features: dict
|
||||
) -> dict:
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
@@ -107,15 +99,12 @@ def compute_episode_stats(
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims
|
||||
)
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
@@ -130,17 +119,11 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError(
|
||||
"Number of dimensions must be at least 1, and is 0 instead."
|
||||
)
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(
|
||||
f"Shape of 'count' must be (1), but is {v.shape} instead."
|
||||
)
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(
|
||||
f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead."
|
||||
)
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
|
||||
|
||||
def aggregate_feature_stats(
|
||||
|
||||
@@ -58,9 +58,7 @@ def resolve_delta_timestamps(
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [
|
||||
i / ds_meta.fps for i in cfg.observation_delta_indices
|
||||
]
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
delta_timestamps = None
|
||||
@@ -81,9 +79,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
LeRobotDataset | MultiLeRobotDataset
|
||||
"""
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms)
|
||||
if cfg.dataset.image_transforms.enable
|
||||
else None
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
@@ -117,8 +113,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(
|
||||
stats, dtype=torch.float32
|
||||
)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -38,14 +38,10 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_pil_image(
|
||||
image_array: np.ndarray, range_check: bool = True
|
||||
) -> PIL.Image.Image:
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(
|
||||
f"The array has {image_array.ndim} dimensions, but 3 is expected for an image."
|
||||
)
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
@@ -131,9 +127,7 @@ class AsyncImageWriter:
|
||||
self._stopped = False
|
||||
|
||||
if num_threads <= 0 and num_processes <= 0:
|
||||
raise ValueError(
|
||||
"Number of threads and processes must be greater than zero."
|
||||
)
|
||||
raise ValueError("Number of threads and processes must be greater than zero.")
|
||||
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
@@ -147,16 +141,12 @@ class AsyncImageWriter:
|
||||
# Use multiprocessing
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
p = multiprocessing.Process(
|
||||
target=worker_process, args=(self.queue, self.num_threads)
|
||||
)
|
||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
|
||||
):
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
|
||||
@@ -108,9 +108,7 @@ class LeRobotDatasetMetadata:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(
|
||||
self.stats, self.episodes
|
||||
)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
@@ -141,9 +139,7 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.video_path.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
|
||||
)
|
||||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
@@ -187,11 +183,7 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [
|
||||
key
|
||||
for key, ft in self.features.items()
|
||||
if ft["dtype"] in ["video", "image"]
|
||||
]
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
@@ -240,9 +232,7 @@ class LeRobotDatasetMetadata:
|
||||
Given a task in natural language, add it to the dictionary of tasks.
|
||||
"""
|
||||
if task in self.task_to_task_index:
|
||||
raise ValueError(
|
||||
f"The task '{task}' already exists and can't be added twice."
|
||||
)
|
||||
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
||||
|
||||
task_index = self.info["total_tasks"]
|
||||
self.task_to_task_index[task] = task_index
|
||||
@@ -285,11 +275,7 @@ class LeRobotDatasetMetadata:
|
||||
write_episode(episode_dict, self.root)
|
||||
|
||||
self.episodes_stats[episode_index] = episode_stats
|
||||
self.stats = (
|
||||
aggregate_stats([self.stats, episode_stats])
|
||||
if self.stats
|
||||
else episode_stats
|
||||
)
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
|
||||
def update_video_info(self) -> None:
|
||||
@@ -299,9 +285,7 @@ class LeRobotDatasetMetadata:
|
||||
"""
|
||||
for key in self.video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.get_video_file_path(
|
||||
ep_index=0, vid_key=key
|
||||
)
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -353,17 +337,13 @@ class LeRobotDatasetMetadata:
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
for key in features:
|
||||
if "/" in key:
|
||||
raise ValueError(
|
||||
f"Feature names should not contain '/'. Found '/' in feature '{key}'."
|
||||
)
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION, fps, robot_type, features, use_videos
|
||||
)
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
@@ -494,9 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = (
|
||||
video_backend if video_backend else get_safe_default_codec()
|
||||
)
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.delta_indices = None
|
||||
|
||||
# Unused attributes
|
||||
@@ -509,39 +487,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse(
|
||||
"v2.1"
|
||||
):
|
||||
episodes_stats = [
|
||||
self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
|
||||
]
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
|
||||
# Load actual data
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
assert all(
|
||||
(self.root / fpath).is_file()
|
||||
for fpath in self.get_episodes_file_paths()
|
||||
)
|
||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(
|
||||
self.meta.episodes, self.episodes
|
||||
)
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(
|
||||
timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
|
||||
)
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
@@ -593,9 +560,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
if not hub_api.file_exists(
|
||||
self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
|
||||
):
|
||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
@@ -603,12 +568,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
if tag_version:
|
||||
with contextlib.suppress(RevisionNotFoundError):
|
||||
hub_api.delete_tag(
|
||||
self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
|
||||
)
|
||||
hub_api.create_tag(
|
||||
self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
|
||||
)
|
||||
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -640,11 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
episodes = (
|
||||
self.episodes
|
||||
if self.episodes is not None
|
||||
else list(range(self.meta.total_episodes))
|
||||
)
|
||||
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
|
||||
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_files = [
|
||||
@@ -662,10 +619,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
path = str(self.root / "data")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [
|
||||
str(self.root / self.meta.get_data_file_path(ep_idx))
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
@@ -675,9 +629,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
ft_dict = {col: [] for col in features}
|
||||
hf_dataset = datasets.Dataset.from_dict(
|
||||
ft_dict, features=features, split="train"
|
||||
)
|
||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
@@ -691,20 +643,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames in selected episodes."""
|
||||
return (
|
||||
len(self.hf_dataset)
|
||||
if self.hf_dataset is not None
|
||||
else self.meta.total_frames
|
||||
)
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
return (
|
||||
len(self.episodes)
|
||||
if self.episodes is not None
|
||||
else self.meta.total_episodes
|
||||
)
|
||||
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
@@ -718,24 +662,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(
|
||||
self, idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int | bool]]]:
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
query_indices = {
|
||||
key: [
|
||||
max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
|
||||
for delta in delta_idx
|
||||
]
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[
|
||||
(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
|
||||
for delta in delta_idx
|
||||
]
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -763,9 +699,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(
|
||||
self, query_timestamps: dict[str, list[float]], ep_idx: int
|
||||
) -> dict[str, torch.Tensor]:
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
@@ -774,9 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
return item
|
||||
@@ -830,9 +762,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = (
|
||||
self.meta.total_episodes if episode_index is None else episode_index
|
||||
)
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
ep_buffer = {}
|
||||
# size and task are special cases that are not in self.features
|
||||
ep_buffer["size"] = 0
|
||||
@@ -841,17 +771,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(
|
||||
self, episode_index: int, image_key: str, frame_index: int
|
||||
) -> Path:
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
|
||||
) -> None:
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
@@ -877,9 +803,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = (
|
||||
frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
)
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
|
||||
@@ -930,9 +854,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_tasks = list(set(tasks))
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
|
||||
episode_buffer["index"] = np.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Add new tasks to the tasks dictionary
|
||||
@@ -942,9 +864,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.meta.add_task(task)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array(
|
||||
[self.meta.get_task_index(task) for task in tasks]
|
||||
)
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
|
||||
for key, ft in self.features.items():
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
@@ -994,9 +914,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(
|
||||
episode_dict, features=self.hf_features, split="train"
|
||||
)
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
|
||||
self.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
@@ -1115,9 +1033,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = (
|
||||
video_backend if video_backend is not None else get_safe_default_codec()
|
||||
)
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
return obj
|
||||
|
||||
|
||||
@@ -1142,9 +1058,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = (
|
||||
tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
)
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
@@ -1223,13 +1137,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in dataset.hf_features.items()
|
||||
if k not in self.disabled_features
|
||||
}
|
||||
)
|
||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -1290,9 +1198,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError(
|
||||
"We expect the loop to break out as long as the index is within bounds."
|
||||
)
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
|
||||
@@ -131,9 +131,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(
|
||||
self, data_spec: dict[str, Any], buffer_capacity: int
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
@@ -208,9 +206,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_frames > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
next_index - 1
|
||||
]
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
@@ -245,11 +241,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(
|
||||
np.unique(
|
||||
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
|
||||
]
|
||||
)
|
||||
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -287,9 +279,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
||||
)
|
||||
)[0]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
|
||||
episode_data_indices
|
||||
]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
||||
|
||||
for data_key in self.delta_timestamps:
|
||||
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
||||
@@ -306,8 +296,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Check violated query timestamps are all outside the episode range.
|
||||
assert (
|
||||
(query_ts[is_pad] < episode_timestamps[0])
|
||||
| (episode_timestamps[-1] < query_ts[is_pad])
|
||||
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
||||
).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
||||
") inside the episode range."
|
||||
@@ -322,9 +311,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
|
||||
"""Returns all data for a given data key as a Tensor."""
|
||||
return torch.from_numpy(
|
||||
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
|
||||
)
|
||||
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
@@ -355,19 +342,13 @@ def compute_sampler_weights(
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (
|
||||
online_dataset is None or len(online_dataset) == 0
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of `offline_dataset` or `online_dataset` should be contain data."
|
||||
)
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = (
|
||||
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
|
||||
weights = []
|
||||
|
||||
|
||||
@@ -45,9 +45,7 @@ def concatenate_episodes(ep_dicts):
|
||||
return data_dict
|
||||
|
||||
|
||||
def save_images_concurrently(
|
||||
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
|
||||
):
|
||||
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -57,10 +55,7 @@ def save_images_concurrently(
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[
|
||||
executor.submit(save_image, imgs_array[i], i, out_dir)
|
||||
for i in range(num_images)
|
||||
]
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
@@ -69,8 +64,7 @@ def get_default_encoding() -> dict:
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
if v.default is not inspect.Parameter.empty
|
||||
and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -58,9 +58,7 @@ class RandomSubsetApply(Transform):
|
||||
elif not isinstance(n_subset, int):
|
||||
raise TypeError("n_subset should be an int or None")
|
||||
elif not (1 <= n_subset <= len(transforms)):
|
||||
raise ValueError(
|
||||
f"n_subset should be in the interval [1, {len(transforms)}]"
|
||||
)
|
||||
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
|
||||
|
||||
self.transforms = transforms
|
||||
total = sum(p)
|
||||
@@ -121,36 +119,26 @@ class SharpnessJitter(Transform):
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError(
|
||||
"If sharpness is a single number, it must be non negative."
|
||||
)
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
sharpness[0] = max(sharpness[0], 0.0)
|
||||
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
||||
sharpness = [float(v) for v in sharpness]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{sharpness=} should be a single number or a sequence with length 2."
|
||||
)
|
||||
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
|
||||
|
||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||
raise ValueError(
|
||||
f"sharpnesss values should be between (0., inf), but got {sharpness}."
|
||||
)
|
||||
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
|
||||
sharpness_factor = (
|
||||
torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
)
|
||||
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
return {"sharpness_factor": sharpness_factor}
|
||||
|
||||
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
|
||||
sharpness_factor = params["sharpness_factor"]
|
||||
return self._call_kernel(
|
||||
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
|
||||
)
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -52,15 +52,9 @@ STATS_PATH = "meta/stats.json"
|
||||
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = (
|
||||
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
)
|
||||
DEFAULT_PARQUET_PATH = (
|
||||
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
)
|
||||
DEFAULT_IMAGE_PATH = (
|
||||
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
)
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
@@ -135,9 +129,7 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
elif isinstance(value, (int, float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The value '{value}' of type '{type(value)}' is not supported."
|
||||
)
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
@@ -216,10 +208,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
tasks = {
|
||||
item["task_index"]: item["task"]
|
||||
for item in sorted(tasks, key=lambda x: x["task_index"])
|
||||
}
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
@@ -230,10 +219,7 @@ def write_episode(episode: dict, local_dir: Path):
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
return {
|
||||
item["episode_index"]: item
|
||||
for item in sorted(episodes, key=lambda x: x["episode_index"])
|
||||
}
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
|
||||
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
@@ -286,9 +272,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [
|
||||
x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]
|
||||
]
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
@@ -341,9 +325,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
Otherwise, will throw a `CompatibilityError`.
|
||||
"""
|
||||
target_version = (
|
||||
packaging.version.parse(version)
|
||||
if not isinstance(version, packaging.version.Version)
|
||||
else version
|
||||
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
||||
)
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
@@ -364,16 +346,12 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
return f"v{target_version}"
|
||||
|
||||
compatibles = [
|
||||
v
|
||||
for v in hub_versions
|
||||
if v.major == target_version.major and v.minor <= target_version.minor
|
||||
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
|
||||
]
|
||||
if compatibles:
|
||||
return_version = max(compatibles)
|
||||
if return_version < target_version:
|
||||
logging.warning(
|
||||
f"Revision {version} for {repo_id} not found, using version v{return_version}"
|
||||
)
|
||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
||||
return f"v{return_version}"
|
||||
|
||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||
@@ -480,9 +458,7 @@ def create_empty_dataset_info(
|
||||
def get_episode_data_index(
|
||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {
|
||||
ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()
|
||||
}
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
@@ -532,9 +508,7 @@ def check_timestamps_sync(
|
||||
|
||||
# Mask to ignore differences at the boundaries between episodes
|
||||
mask = np.ones(len(diffs), dtype=bool)
|
||||
ignored_diffs = (
|
||||
episode_data_index["to"][:-1] - 1
|
||||
) # indices at the end of each episode
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
@@ -580,14 +554,10 @@ def check_delta_timestamps(
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [
|
||||
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
|
||||
]
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts
|
||||
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
|
||||
if not is_within
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
]
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
@@ -605,9 +575,7 @@ def check_delta_timestamps(
|
||||
return True
|
||||
|
||||
|
||||
def get_delta_indices(
|
||||
delta_timestamps: dict[str, list[float]], fps: int
|
||||
) -> dict[str, list[int]]:
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
@@ -672,9 +640,7 @@ def create_lerobot_dataset_card(
|
||||
],
|
||||
)
|
||||
|
||||
card_template = (
|
||||
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
|
||||
).read_text()
|
||||
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
|
||||
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
@@ -743,18 +709,14 @@ def validate_frame(frame: dict, features: dict):
|
||||
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||
actual_features = set(frame.keys())
|
||||
|
||||
error_message = validate_features_presence(
|
||||
actual_features, expected_features, optional_features
|
||||
)
|
||||
error_message = validate_features_presence(actual_features, expected_features, optional_features)
|
||||
|
||||
if "task" in frame:
|
||||
error_message += validate_feature_string("task", frame["task"])
|
||||
|
||||
common_features = actual_features & (expected_features | optional_features)
|
||||
for name in common_features - {"task"}:
|
||||
error_message += validate_feature_dtype_and_shape(
|
||||
name, features[name], frame[name]
|
||||
)
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
@@ -777,9 +739,7 @@ def validate_features_presence(
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(
|
||||
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
|
||||
):
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
@@ -789,9 +749,7 @@ def validate_feature_dtype_and_shape(
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The feature dtype '{expected_dtype}' is not implemented yet."
|
||||
)
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def validate_feature_numpy_array(
|
||||
@@ -813,17 +771,13 @@ def validate_feature_numpy_array(
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_image_or_video(
|
||||
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
|
||||
):
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (
|
||||
actual_shape != (c, h, w) and actual_shape != (h, w, c)
|
||||
):
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
@@ -854,9 +808,7 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError(
|
||||
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
||||
)
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
|
||||
@@ -218,9 +218,7 @@ def get_features_from_hf_dataset(
|
||||
dtype = ft.feature.dtype
|
||||
shape = (ft.length,)
|
||||
motor_names = (
|
||||
robot_config["names"][key]
|
||||
if robot_config
|
||||
else [f"motor_{i}" for i in range(ft.length)]
|
||||
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
|
||||
)
|
||||
assert len(motor_names) == shape[0]
|
||||
names = {"motors": motor_names}
|
||||
@@ -244,15 +242,11 @@ def get_features_from_hf_dataset(
|
||||
return features
|
||||
|
||||
|
||||
def add_task_index_by_episodes(
|
||||
dataset: Dataset, tasks_by_episodes: dict
|
||||
) -> tuple[Dataset, list[str]]:
|
||||
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
tasks = list(set(tasks_by_episodes.values()))
|
||||
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
|
||||
episodes_to_task_index = {
|
||||
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
|
||||
|
||||
features = dataset.features
|
||||
@@ -269,19 +263,10 @@ def add_task_index_from_tasks_col(
|
||||
# HACK: This is to clean some of the instructions in our version of Open X datasets
|
||||
prefix_to_clean = "tf.Tensor(b'"
|
||||
suffix_to_clean = "', shape=(), dtype=string)"
|
||||
df[tasks_col] = (
|
||||
df[tasks_col]
|
||||
.str.removeprefix(prefix_to_clean)
|
||||
.str.removesuffix(suffix_to_clean)
|
||||
)
|
||||
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
|
||||
|
||||
# Create task_index col
|
||||
tasks_by_episode = (
|
||||
df.groupby("episode_index")[tasks_col]
|
||||
.unique()
|
||||
.apply(lambda x: x.tolist())
|
||||
.to_dict()
|
||||
)
|
||||
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
|
||||
tasks = df[tasks_col].unique().tolist()
|
||||
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
|
||||
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
|
||||
@@ -306,9 +291,7 @@ def split_parquet_by_episodes(
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
|
||||
episode_chunk=ep_chunk
|
||||
)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
@@ -340,9 +323,7 @@ def move_videos(
|
||||
videos_moved = False
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
|
||||
if len(video_files) == 0:
|
||||
video_files = [
|
||||
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
|
||||
]
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
|
||||
videos_moved = True # Videos have already been moved
|
||||
|
||||
assert len(video_files) == total_episodes * len(video_keys)
|
||||
@@ -373,9 +354,7 @@ def move_videos(
|
||||
target_path = DEFAULT_VIDEO_PATH.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(
|
||||
video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
|
||||
if len(video_dirs) == 1:
|
||||
video_path = video_dirs[0] / video_file
|
||||
else:
|
||||
@@ -392,9 +371,7 @@ def move_videos(
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_lfs_video_files_tracking(
|
||||
work_dir: Path, lfs_untracked_videos: list[str]
|
||||
) -> None:
|
||||
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
|
||||
"""
|
||||
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
|
||||
there's no other option than to download the actual files and reupload them with lfs tracking.
|
||||
@@ -418,14 +395,10 @@ def fix_lfs_video_files_tracking(
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_gitattributes(
|
||||
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
|
||||
) -> None:
|
||||
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
|
||||
shutil.copyfile(clean_gittatributes, current_gittatributes)
|
||||
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
|
||||
)
|
||||
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
@@ -462,9 +435,7 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def get_videos_info(
|
||||
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
|
||||
) -> dict:
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
||||
# Assumes first episode
|
||||
video_files = [
|
||||
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
||||
@@ -539,31 +510,19 @@ def convert_dataset(
|
||||
if single_task:
|
||||
tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_path:
|
||||
tasks_by_episodes = load_json(tasks_path)
|
||||
tasks_by_episodes = {
|
||||
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_col:
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
|
||||
dataset, tasks_col
|
||||
)
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
assert set(tasks) == {
|
||||
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
|
||||
}
|
||||
tasks = [
|
||||
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
|
||||
]
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
@@ -593,9 +552,7 @@ def convert_dataset(
|
||||
clean_gitattr,
|
||||
branch,
|
||||
)
|
||||
videos_info = get_videos_info(
|
||||
repo_id, v1x_dir, video_keys=video_keys, branch=branch
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||
for key in video_keys:
|
||||
features[key]["shape"] = (
|
||||
videos_info[key].pop("video.height"),
|
||||
@@ -603,22 +560,15 @@ def convert_dataset(
|
||||
videos_info[key].pop("video.channels"),
|
||||
)
|
||||
features[key]["video_info"] = videos_info[key]
|
||||
assert math.isclose(
|
||||
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
|
||||
)
|
||||
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
||||
if "encoding" in metadata_v1:
|
||||
assert (
|
||||
videos_info[key]["video.pix_fmt"]
|
||||
== metadata_v1["encoding"]["pix_fmt"]
|
||||
)
|
||||
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
||||
else:
|
||||
assert metadata_v1.get("video", 0) == 0
|
||||
videos_info = None
|
||||
|
||||
# Split data into 1 parquet file by episode
|
||||
episode_lengths = split_parquet_by_episodes(
|
||||
dataset, total_episodes, total_chunks, v20_dir
|
||||
)
|
||||
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
||||
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config.type
|
||||
@@ -656,14 +606,10 @@ def convert_dataset(
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
|
||||
)
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
|
||||
)
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
@@ -674,9 +620,7 @@ def convert_dataset(
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
|
||||
)
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
|
||||
@@ -35,30 +35,22 @@ def fix_dataset(repo_id: str) -> str:
|
||||
|
||||
dataset_info = get_dataset_config_info(repo_id, "default")
|
||||
with SuppressWarnings():
|
||||
lerobot_metadata = LeRobotDatasetMetadata(
|
||||
repo_id, revision=V20, force_cache_sync=True
|
||||
)
|
||||
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
meta_features = {
|
||||
key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"
|
||||
}
|
||||
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
|
||||
parquet_features = set(dataset_info.features)
|
||||
|
||||
diff_parquet_meta = parquet_features - meta_features
|
||||
diff_meta_parquet = meta_features - parquet_features
|
||||
|
||||
if diff_parquet_meta:
|
||||
raise ValueError(
|
||||
f"In parquet not in info.json: {parquet_features - meta_features}"
|
||||
)
|
||||
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
|
||||
|
||||
if not diff_meta_parquet:
|
||||
return f"{repo_id}: skipped (no diff)"
|
||||
|
||||
if diff_meta_parquet:
|
||||
logging.warning(
|
||||
f"In info.json not in parquet: {meta_features - parquet_features}"
|
||||
)
|
||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
||||
assert diff_meta_parquet == {"language_instruction"}
|
||||
lerobot_metadata.features.pop("language_instruction")
|
||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||
|
||||
@@ -99,9 +99,7 @@ def convert_dataset(
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
hub_api.create_tag(
|
||||
repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
|
||||
)
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -26,9 +26,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(
|
||||
dataset: LeRobotDataset, episode_index: int, ft_key: str
|
||||
) -> np.ndarray:
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
ep_len = dataset.meta.episodes[episode_index]["length"]
|
||||
sampled_indices = sample_indices(ep_len)
|
||||
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
||||
@@ -51,14 +49,11 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_data, axis=axes_to_reduce, keepdims=keepdims
|
||||
)
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
|
||||
@@ -65,9 +65,7 @@ def decode_video_frames(
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
@@ -346,9 +344,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
@@ -362,9 +358,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"])
|
||||
if audio_stream_info.get("bit_rate")
|
||||
else None,
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
@@ -386,9 +380,7 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user