forked from tangger/lerobot
Remove local_files_only and use codebase_version instead of branches (#734)
This commit is contained in:
40
lerobot/common/datasets/backward_compatibility.py
Normal file
40
lerobot/common/datasets/backward_compatibility.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import packaging.version
|
||||
|
||||
V2_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
|
||||
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
|
||||
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
|
||||
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
|
||||
sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
V21_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||
```
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
|
||||
class BackwardCompatibilityError(Exception):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
@@ -83,15 +83,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
|
||||
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, revision=cfg.dataset.revision)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
local_files_only=cfg.dataset.local_files_only,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@@ -27,6 +26,8 @@ import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from packaging import version
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||
@@ -41,14 +42,13 @@ from lerobot.common.datasets.utils import (
|
||||
check_frame_features,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_branch,
|
||||
create_empty_dataset_info,
|
||||
create_lerobot_dataset_card,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_features_from_features,
|
||||
get_hub_safe_version,
|
||||
get_safe_revision,
|
||||
hf_transform_to_torch,
|
||||
load_episodes,
|
||||
load_episodes_stats,
|
||||
@@ -79,30 +79,35 @@ class LeRobotDatasetMetadata:
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
||||
self.local_files_only = local_files_only
|
||||
|
||||
# Load metadata
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
self.load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.revision = get_safe_revision(self.repo_id, self.revision)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
try:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"""'episodes_stats.jsonl' not found. Using global dataset stats for each episode instead.
|
||||
Convert your dataset stats to the new format using this command:
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={self.repo_id} """
|
||||
)
|
||||
if version.parse(self._version) < version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -112,17 +117,12 @@ class LeRobotDatasetMetadata:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self._hub_version,
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
local_files_only=self.local_files_only,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _hub_version(self) -> str | None:
|
||||
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
||||
|
||||
@property
|
||||
def _version(self) -> str:
|
||||
"""Codebase version used to create this dataset."""
|
||||
@@ -342,7 +342,7 @@ class LeRobotDatasetMetadata:
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.local_files_only = True
|
||||
obj.revision = None
|
||||
return obj
|
||||
|
||||
|
||||
@@ -355,8 +355,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -366,7 +367,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
- On your local disk in the 'root' folder. This is typically the case when you recorded your
|
||||
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
|
||||
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
|
||||
internet connection), in that case, use local_files_only=True.
|
||||
internet connection).
|
||||
|
||||
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
|
||||
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
|
||||
@@ -448,11 +449,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
||||
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
||||
multiples of 1/fps. Defaults to 1e-4.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. Defaults to current codebase version tag.
|
||||
sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files
|
||||
are already present in the local cache, this will be faster. However, files loaded might not
|
||||
be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
|
||||
False.
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
|
||||
will be made. Defaults to False.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
"""
|
||||
@@ -463,9 +468,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else "pyav"
|
||||
self.delta_indices = None
|
||||
self.local_files_only = local_files_only
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -474,17 +479,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
|
||||
if self.episodes is not None and self.meta._version == CODEBASE_VERSION:
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and version.parse(self.meta._version) >= version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_revision(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
@@ -501,7 +513,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
create_card: bool = True,
|
||||
tags: list | None = None,
|
||||
license: str | None = "apache-2.0",
|
||||
push_videos: bool = True,
|
||||
@@ -528,7 +539,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
exist_ok=True,
|
||||
)
|
||||
if branch:
|
||||
create_branch(repo_id=self.repo_id, branch=branch, repo_type="dataset")
|
||||
hub_api.create_branch(
|
||||
repo_id=self.repo_id,
|
||||
branch=branch,
|
||||
revision=self.revision,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=self.repo_id,
|
||||
@@ -538,15 +555,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
if create_card:
|
||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
if not branch:
|
||||
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
@@ -555,11 +569,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.meta._hub_version,
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
local_files_only=self.local_files_only,
|
||||
)
|
||||
|
||||
def download_episodes(self, download_videos: bool = True) -> None:
|
||||
@@ -573,17 +586,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
if self.episodes is not None:
|
||||
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
if len(self.meta.video_keys) > 0 and download_videos:
|
||||
video_files = [
|
||||
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.meta.video_keys
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
files += video_files
|
||||
files = self.get_episodes_file_paths()
|
||||
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
|
||||
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_files = [
|
||||
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.meta.video_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
return fpaths
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
@@ -991,7 +1010,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
obj.local_files_only = obj.meta.local_files_only
|
||||
obj.revision = None
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
|
||||
@@ -1033,7 +1052,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1051,7 +1069,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
local_files_only=local_files_only,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
|
||||
@@ -13,10 +13,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from collections.abc import Iterator
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
@@ -31,9 +31,11 @@ import pyarrow.compute as pc
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from packaging import version
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.datasets.backward_compatibility import V21_MESSAGE, BackwardCompatibilityError
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
@@ -200,7 +202,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
append_jsonlines(task_dict, local_dir / TASKS_PATH)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
@@ -231,7 +233,9 @@ def load_episodes_stats(local_dir: Path) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def backward_compatible_episodes_stats(stats, episodes: list[int]) -> dict[str, dict[str, np.ndarray]]:
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
return {ep_idx: stats for ep_idx in episodes}
|
||||
|
||||
|
||||
@@ -265,73 +269,38 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
return items_dict
|
||||
|
||||
|
||||
def _get_major_minor(version: str) -> tuple[int]:
|
||||
split = version.strip("v").split(".")
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
class BackwardCompatibilityError(Exception):
|
||||
def __init__(self, repo_id, version):
|
||||
message = textwrap.dedent(f"""
|
||||
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
|
||||
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
|
||||
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
""")
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
|
||||
) -> None:
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
if major_to_check < current_major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, version_to_check)
|
||||
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
|
||||
logging.warning(
|
||||
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
|
||||
codebase. The current codebase version is {current_version}. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
)
|
||||
v_check = version.parse(version_to_check)
|
||||
v_current = version.parse(current_version)
|
||||
if v_check.major < v_current.major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, v_check)
|
||||
elif v_check.minor < v_current.minor:
|
||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=version_to_check))
|
||||
|
||||
|
||||
def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
def get_repo_versions(repo_id: str) -> list[version.Version]:
|
||||
"""Returns available valid versions (branches and tags) on given repo."""
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if version not in branches:
|
||||
num_version = float(version.strip("v"))
|
||||
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
|
||||
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
|
||||
raise BackwardCompatibilityError(repo_id, version)
|
||||
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
|
||||
repo_versions = []
|
||||
for ref in repo_refs:
|
||||
with contextlib.suppress(version.InvalidVersion):
|
||||
repo_versions.append(version.parse(ref))
|
||||
|
||||
logging.warning(
|
||||
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||
codebase. The following versions are available: {branches}.
|
||||
The requested version ('{version}') is not found. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
)
|
||||
if "main" not in branches:
|
||||
raise ValueError(f"Version 'main' not found on {repo_id}")
|
||||
return "main"
|
||||
else:
|
||||
return version
|
||||
return repo_versions
|
||||
|
||||
|
||||
def get_safe_revision(repo_id: str, revision: str) -> str:
|
||||
"""Returns the version if available on repo, otherwise return the latest available."""
|
||||
api = HfApi()
|
||||
if api.revision_exists(repo_id, revision, repo_type="dataset"):
|
||||
return revision
|
||||
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
return f"v{max(hub_versions)}"
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
|
||||
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_hub_safe_version,
|
||||
get_safe_revision,
|
||||
load_json,
|
||||
unflatten_dict,
|
||||
write_json,
|
||||
@@ -443,7 +443,7 @@ def convert_dataset(
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
v1 = get_hub_safe_version(repo_id, V16)
|
||||
v1 = get_safe_revision(repo_id, V16)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
logfile = LOCAL_DIR / "conversion_log_v21.txt"
|
||||
for num, repo_id in available_datasets:
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
convert_dataset(repo_id)
|
||||
status = f"{repo_id}: success."
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
@@ -1,10 +1,12 @@
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It performs the following:
|
||||
2.1. It will:
|
||||
|
||||
- Generates per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Removes the deprecated `stats.json` (by default)
|
||||
- Updates codebase_version in `info.json`
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
@@ -14,9 +16,9 @@ python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
|
||||
```
|
||||
|
||||
"""
|
||||
# TODO(rcadene, aliberts): ensure this script works for any other changes for the final v2.1
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
@@ -24,14 +26,27 @@ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDat
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
|
||||
def main(
|
||||
|
||||
class SuppressWarnings:
|
||||
def __enter__(self):
|
||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
logging.getLogger().setLevel(self.previous_level)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
test_branch: str | None = None,
|
||||
delete_old_stats: bool = False,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
raise FileExistsError("episodes_stats.jsonl already exists.")
|
||||
|
||||
@@ -42,18 +57,21 @@ def main(
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
dataset.push_to_hub(branch=test_branch, create_card=False, allow_patterns="meta/")
|
||||
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
|
||||
|
||||
if delete_old_stats:
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
|
||||
)
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -65,16 +83,10 @@ if __name__ == "__main__":
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-branch",
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete-old-stats",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Delete the deprecated `stats.json`",
|
||||
help="Repo branch to push your dataset (defaults to the main branch)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
@@ -84,4 +96,4 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
convert_dataset(**vars(args))
|
||||
|
||||
@@ -81,9 +81,6 @@ class RecordControlConfig(ControlConfig):
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
|
||||
local_files_only: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
@@ -128,9 +125,6 @@ class ReplayControlConfig(ControlConfig):
|
||||
fps: int | None = None
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
|
||||
local_files_only: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user