[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by AdilZouitine
parent 761a2dbcb3
commit 8e6d5f504c
97 changed files with 1596 additions and 492 deletions

View File

@@ -108,7 +108,9 @@ class LeRobotDatasetMetadata:
self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
self.episodes_stats = backward_compatible_episodes_stats(
self.stats, self.episodes
)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
@@ -238,7 +240,9 @@ class LeRobotDatasetMetadata:
Given a task in natural language, add it to the dictionary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
raise ValueError(
f"The task '{task}' already exists and can't be added twice."
)
task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index
@@ -281,7 +285,11 @@ class LeRobotDatasetMetadata:
write_episode(episode_dict, self.root)
self.episodes_stats[episode_index] = episode_stats
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
self.stats = (
aggregate_stats([self.stats, episode_stats])
if self.stats
else episode_stats
)
write_episode_stats(episode_index, episode_stats, self.root)
def update_video_info(self) -> None:
@@ -345,13 +353,17 @@ class LeRobotDatasetMetadata:
# as this would break the dict flattening in the stats computation, which uses '/' as separator
for key in features:
if "/" in key:
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
raise ValueError(
f"Feature names should not contain '/'. Found '/' in feature '{key}'."
)
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.task_to_task_index = {}, {}
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, features, use_videos
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
@@ -482,7 +494,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.video_backend = (
video_backend if video_backend else get_safe_default_codec()
)
self.delta_indices = None
# Unused attributes
@@ -495,28 +509,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
if self.episodes is not None and self.meta._version >= packaging.version.parse(
"v2.1"
):
episodes_stats = [
self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
]
self.stats = aggregate_stats(episodes_stats)
# Load actual data
try:
if force_cache_sync:
raise FileNotFoundError
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
assert all(
(self.root / fpath).is_file()
for fpath in self.get_episodes_file_paths()
)
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
check_timestamps_sync(
timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
)
# Setup delta_indices
if self.delta_timestamps is not None:
@@ -568,7 +593,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
hub_api.upload_folder(**upload_kwargs)
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
if not hub_api.file_exists(
self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
@@ -576,8 +603,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.delete_tag(
self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
)
hub_api.create_tag(
self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
def pull_from_repo(
self,
@@ -609,7 +640,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]:
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
episodes = (
self.episodes
if self.episodes is not None
else list(range(self.meta.total_episodes))
)
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [
@@ -640,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
hf_dataset = datasets.Dataset.from_dict(
ft_dict, features=features, split="train"
)
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
@@ -726,7 +763,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
def _query_videos(
self, query_timestamps: dict[str, list[float]], ep_idx: int
) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -735,7 +774,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
frames = decode_video_frames(
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0)
return item
@@ -789,7 +830,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
current_ep_idx = (
self.meta.total_episodes if episode_index is None else episode_index
)
ep_buffer = {}
# size and task are special cases that are not in self.features
ep_buffer["size"] = 0
@@ -887,7 +930,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["index"] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionary
@@ -897,12 +942,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta.add_task(task)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
episode_buffer["task_index"] = np.array(
[self.meta.get_task_index(task) for task in tasks]
)
for key, ft in self.features.items():
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
"image",
"video",
]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])
@@ -944,7 +994,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = datasets.Dataset.from_dict(
episode_dict, features=self.hf_features, split="train"
)
ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
self.hf_dataset.set_transform(hf_transform_to_torch)
@@ -1063,7 +1115,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.video_backend = (
video_backend if video_backend is not None else get_safe_default_codec()
)
return obj
@@ -1088,7 +1142,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
self.tolerances_s = (
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [