Merge remote-tracking branch 'origin/main' into user/rcadene/2025_02_19_port_openx

This commit is contained in:
Remi Cadene
2025-03-01 19:17:18 +00:00
123 changed files with 2489 additions and 629 deletions

View File

@@ -33,8 +33,22 @@ If you encounter a problem, contact LeRobot maintainers on [Discord](https://dis
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
FUTURE_MESSAGE = """
The dataset you requested ({repo_id}) is only available in {version} format.
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
"""
class BackwardCompatibilityError(Exception):
class CompatibilityError(Exception): ...
class BackwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)
class ForwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)

View File

@@ -92,7 +92,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
else:
ep_ft_array = data # data is alreay a np.ndarray
ep_ft_array = data # data is already a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array

View File

@@ -83,10 +83,13 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
)
if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, revision=cfg.dataset.revision)
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import shutil
from pathlib import Path
@@ -20,13 +21,14 @@ from typing import Callable
import datasets
import numpy as np
import packaging.version
import PIL.Image
import torch
import torch.utils
from datasets import load_dataset
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from packaging import version
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
@@ -43,12 +45,14 @@ from lerobot.common.datasets.utils import (
check_version_compatibility,
create_empty_dataset_info,
create_lerobot_dataset_card,
embed_images,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hf_features_from_features,
get_safe_revision,
get_safe_version,
hf_transform_to_torch,
is_valid_version,
load_episodes,
load_episodes_stats,
load_info,
@@ -60,7 +64,6 @@ from lerobot.common.datasets.utils import (
write_episode_stats,
write_info,
write_json,
write_parquet,
)
from lerobot.common.datasets.video_utils import (
VideoFrame,
@@ -70,7 +73,6 @@ from lerobot.common.datasets.video_utils import (
)
from lerobot.common.robot_devices.robots.utils import Robot
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.1"
@@ -91,18 +93,19 @@ class LeRobotDatasetMetadata:
raise FileNotFoundError
self.load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.revision = get_safe_revision(self.repo_id, self.revision)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
if version.parse(self._version) < version.parse("v2.1"):
if self._version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else:
@@ -124,9 +127,9 @@ class LeRobotDatasetMetadata:
)
@property
def _version(self) -> str:
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return self.info["codebase_version"]
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
@@ -225,7 +228,7 @@ class LeRobotDatasetMetadata:
def add_task(self, task: str):
"""
Given a task in natural language, add it to the dictionnary of tasks.
Given a task in natural language, add it to the dictionary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
@@ -388,7 +391,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
- info contains various information about the dataset like shapes, keys, fps etc.
- stats stores the dataset statistics of the different modalities for normalization
- tasks contains the prompts for each task of the dataset, which can be used for
task-conditionned training.
task-conditioned training.
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
@@ -483,7 +486,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and version.parse(self.meta._version) >= version.parse("v2.1"):
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats)
@@ -494,14 +497,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_revision(self.repo_id, self.revision)
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
@@ -513,6 +519,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
branch: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
@@ -558,6 +565,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
@@ -611,7 +623,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@property
@@ -836,7 +856,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionnary
# Add new tasks to the tasks dictionary
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
@@ -864,9 +884,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
# `meta.save_episode` be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
check_timestamps_sync(
episode_buffer["timestamp"],
episode_buffer["episode_index"],
ep_data_index_np,
self.fps,
self.tolerance_s,
)
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
@@ -885,9 +911,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
self.hf_dataset.set_transform(hf_transform_to_torch)
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
write_parquet(ep_dataset, ep_data_path)
ep_dataset.to_parquet(ep_data_path)
def clear_episode_buffer(self) -> None:
episode_index = self.episode_buffer["episode_index"]
@@ -995,7 +1024,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.episode_buffer = obj.create_episode_buffer()
obj.episodes = None
obj.hf_dataset = None
obj.hf_dataset = obj.create_hf_dataset()
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None

View File

@@ -1,56 +0,0 @@
## Using / Updating `CODEBASE_VERSION` (for maintainers)
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
lerobot/common/datasets/lerobot_dataset.py) variable.
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
`info.json` metadata.
### Uploading a new dataset
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
dataset with the current `CODEBASE_VERSION`.
### Updating an existing dataset
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.
However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
```python
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
api = HfApi()
for repo_id in available_datasets:
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if CODEBASE_VERSION in branches:
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
continue
else:
# Now create a branch named after the new version by branching out from "main"
# which is expected to be the preceding version
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
```

View File

@@ -152,7 +152,7 @@ def download_raw(raw_dir: Path, repo_id: str):
stacklevel=1,
)
# Send warning if raw_dir isn't well formated
# Send warning if raw_dir isn't well formatted
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn(
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that

View File

@@ -68,9 +68,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
modality_df,
on="timestamp_utc",
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
# are too far appart.
# are too far apart.
direction="nearest",
tolerance=pd.Timedelta(f"{1 / fps} seconds"),
)
@@ -126,7 +126,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
videos_dir.parent.mkdir(parents=True, exist_ok=True)
videos_dir.symlink_to((raw_dir / "videos").absolute())
# sanity check the video paths are well formated
# sanity check the video paths are well formatted
for key in df:
if "observation.images." not in key:
continue
@@ -143,7 +143,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
# sanity check the video path is well formated
# sanity check the video path is well formatted
video_path = videos_dir.parent / data_dict[key][0]["path"]
if not video_path.exists():
raise ValueError(f"Video file not found in {video_path}")

View File

@@ -17,7 +17,7 @@
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
Example:
python lerobot/scripts/push_dataset_to_hub.py \

View File

@@ -27,15 +27,19 @@ from typing import Any
import datasets
import jsonlines
import numpy as np
import pyarrow.compute as pc
import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from packaging import version
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.common.datasets.backward_compatibility import V21_MESSAGE, BackwardCompatibilityError
from lerobot.common.datasets.backward_compatibility import (
V21_MESSAGE,
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
@@ -129,13 +133,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
return unflatten_dict(serialized_dict)
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
dataset.to_parquet(fpath)
return dataset
def load_json(fpath: Path) -> Any:
@@ -219,7 +223,7 @@ def load_episodes(local_dir: Path) -> dict:
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionnary since `episode_stats["episode_index"]`
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
@@ -269,38 +273,91 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict
def is_valid_version(version: str) -> bool:
try:
packaging.version.parse(version)
return True
except packaging.version.InvalidVersion:
return False
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
repo_id: str,
version_to_check: str | packaging.version.Version,
current_version: str | packaging.version.Version,
enforce_breaking_major: bool = True,
) -> None:
v_check = version.parse(version_to_check)
v_current = version.parse(current_version)
v_check = (
packaging.version.parse(version_to_check)
if not isinstance(version_to_check, packaging.version.Version)
else version_to_check
)
v_current = (
packaging.version.parse(current_version)
if not isinstance(current_version, packaging.version.Version)
else current_version
)
if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=version_to_check))
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_repo_versions(repo_id: str) -> list[version.Version]:
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
"""Returns available valid versions (branches and tags) on given repo."""
api = HfApi()
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
repo_versions = []
for ref in repo_refs:
with contextlib.suppress(version.InvalidVersion):
repo_versions.append(version.parse(ref))
with contextlib.suppress(packaging.version.InvalidVersion):
repo_versions.append(packaging.version.parse(ref))
return repo_versions
def get_safe_revision(repo_id: str, revision: str) -> str:
"""Returns the version if available on repo, otherwise return the latest available."""
api = HfApi()
if api.revision_exists(repo_id, revision, repo_type="dataset"):
return revision
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
"""
Returns the version if available on repo or the latest compatible one.
Otherwise, will throw a `CompatibilityError`.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
)
hub_versions = get_repo_versions(repo_id)
return f"v{max(hub_versions)}"
if not hub_versions:
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
)
if target_version in hub_versions:
return f"v{target_version}"
compatibles = [
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
]
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
if lower_major:
raise BackwardCompatibilityError(repo_id, max(lower_major))
upper_versions = [v for v in hub_versions if v > target_version]
assert len(upper_versions) > 0
raise ForwardCompatibilityError(repo_id, min(upper_versions))
def get_hf_features_from_features(features: dict) -> datasets.Features:
@@ -402,82 +459,79 @@ def get_episode_data_index(
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
cumulative_lenghts = list(accumulate(episode_lengths.values()))
cumulative_lengths = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lenghts = list(accumulate(episode_lengths))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
"to": torch.LongTensor(cumulative_lengths),
}
def check_timestamps_sync(
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
timestamps: np.ndarray,
episode_indices: np.ndarray,
episode_data_index: dict[str, np.ndarray],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
account for possible numerical error.
"""
timestamps = torch.stack(hf_dataset["timestamp"])
diffs = torch.diff(timestamps)
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
to account for possible numerical error.
# We mask differences between the timestamp at the end of an episode
# and the one at the start of the next episode since these are expected
# to be outside tolerance.
mask = torch.ones(len(diffs), dtype=torch.bool)
ignored_diffs = episode_data_index["to"][:-1] - 1
Args:
timestamps (np.ndarray): Array of timestamps in seconds.
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
which identifies indices for the end of each episode.
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
raise_value_error (bool): Whether to raise a ValueError if the check fails.
Returns:
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
Raises:
ValueError: If the check fails and `raise_value_error` is True.
"""
if timestamps.shape != episode_indices.shape:
raise ValueError(
"timestamps and episode_indices should have the same shape. "
f"Found {timestamps.shape=} and {episode_indices.shape=}."
)
# Consecutive differences
diffs = np.diff(timestamps)
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
if not torch.all(filtered_within_tolerance):
# Check if all remaining diffs are within tolerance
if not np.all(filtered_within_tolerance):
# Track original indices before masking
original_indices = torch.arange(len(diffs))
original_indices = np.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])
outside_tolerances = []
for idx in outside_tolerance_indices:
entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx],
"episode_index": episode_indices[idx].item(),
"episode_index": episode_indices[idx].item()
if hasattr(episode_indices[idx], "item")
else episode_indices[idx],
}
outside_tolerances.append(entry)
if raise_value_error:
raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
This might be due to synchronization issues with timestamps during data collection.
This might be due to synchronization issues during data collection.
\n{pformat(outside_tolerances)}"""
)
return False

View File

@@ -31,6 +31,7 @@ from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
LOCAL_DIR = Path("data/")
# spellchecker:off
ALOHA_MOBILE_INFO = {
"robot_config": AlohaRobotConfig(),
"license": "mit",
@@ -856,6 +857,7 @@ DATASETS = {
}""").lstrip(),
},
}
# spellchecker:on
def batch_convert():

View File

@@ -17,7 +17,7 @@
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.
We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_safe_revision,
get_safe_version,
load_json,
unflatten_dict,
write_json,
@@ -443,7 +443,7 @@ def convert_dataset(
test_branch: str | None = None,
**card_kwargs,
):
v1 = get_safe_revision(repo_id, V16)
v1 = get_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -0,0 +1,73 @@
import logging
import traceback
from pathlib import Path
from datasets import get_dataset_config_info
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import INFO_PATH, write_info
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
LOCAL_DIR = Path("data/")
hub_api = HfApi()
def fix_dataset(repo_id: str) -> str:
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
return f"{repo_id}: skipped (not in {V20})."
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)
commit_info = hub_api.upload_file(
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
path_in_repo=INFO_PATH,
repo_id=repo_id,
repo_type="dataset",
revision=V20,
commit_message="Remove 'language_instruction'",
create_pr=True,
)
return f"{repo_id}: success - PR: {commit_info.pr_url}"
def batch_fix():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "fix_features_v20.txt"
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
status = fix_dataset(repo_id)
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
logging.info(status)
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_fix()

View File

@@ -21,8 +21,10 @@ This script is for internal use to convert all datasets under the 'lerobot' hub
import traceback
from pathlib import Path
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
LOCAL_DIR = Path("data/")
@@ -31,19 +33,21 @@ def batch_convert():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "conversion_log_v21.txt"
hub_api = HfApi()
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
convert_dataset(repo_id)
status = f"{repo_id}: success."
with open(logfile, "a") as file:
file.write(status + "\n")
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
status = f"{repo_id}: success (already in {V21})."
else:
convert_dataset(repo_id)
status = f"{repo_id}: success."
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
continue
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":

View File

@@ -48,7 +48,7 @@ def convert_dataset(
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
if (dataset.root / EPISODES_STATS_PATH).is_file():
raise FileExistsError("episodes_stats.jsonl already exists.")
(dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
@@ -57,7 +57,7 @@ def convert_dataset(
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:

View File

@@ -65,7 +65,7 @@ def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
default_rtol_atol: tuple[float] = (5e-6, 0.0),
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
):
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))

View File

@@ -73,7 +73,7 @@ def decode_video_frames_torchvision(
last_ts = max(timestamps)
# access closest key frame of the first requested frame
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts, keyframes_only=keyframes_only)