Fix datasets missing versions (#318)

Author: Simon Alibert
Date: 2024-07-16 23:02:31 +02:00
Committed by: GitHub
Parent: 5f5efe7cb9
Commit: 8865e19c12
12 changed files with 156 additions and 120 deletions


@@ -15,13 +15,15 @@
 # limitations under the License.
 import json
 import re
+import warnings
+from functools import cache
 from pathlib import Path
 from typing import Dict

 import datasets
 import torch
 from datasets import load_dataset, load_from_disk
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import HfApi, hf_hub_download, snapshot_download
 from PIL import Image as PILImage
 from safetensors.torch import load_file
 from torchvision import transforms
@@ -80,7 +82,28 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict


-def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
+@cache
+def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
+    api = HfApi()
+    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
+    branches = [b.name for b in dataset_info.branches]
+    if version not in branches:
+        warnings.warn(
+            f"""You are trying to load a dataset from {repo_id} created with a previous version of the
+            codebase. The following versions are available: {branches}.
+            The requested version ('{version}') is not found. You should be fine since
+            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
+            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
+            stacklevel=1,
+        )
+        if "main" not in branches:
+            raise ValueError(f"Version 'main' not found on {repo_id}")
+        return "main"
+    else:
+        return version
+
+
+def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
         hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
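For reference, a minimal sketch of how the new helper behaves. The repo id and version strings below are illustrative, and the import path assumes the helper lives in lerobot/common/datasets/utils:

    from lerobot.common.datasets.utils import get_hf_dataset_safe_version

    # If a "v1.3" branch exists on the Hub dataset repo, it is returned as-is.
    # Otherwise a warning is emitted and the helper falls back to "main",
    # raising ValueError only if "main" is missing as well.
    safe_version = get_hf_dataset_safe_version("lerobot/pusht", "v1.3")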
@@ -101,7 +124,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
                     f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
                 )
     else:
-        hf_dataset = load_dataset(repo_id, revision=version, split=split)
+        safe_version = get_hf_dataset_safe_version(repo_id, version)
+        hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)

     hf_dataset.set_transform(hf_transform_to_torch)

     return hf_dataset
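With the fallback in place, a pinned version can be requested without failing on Hub repos that predate it. A hedged usage sketch (the values passed are illustrative):

    # Loads the first 100 frames from the Hub; if no "v1.4" branch exists on
    # the repo, the helper warns and resolves the revision to "main" instead.
    hf_dataset = load_hf_dataset("lerobot/pusht", version="v1.4", root=None, split="train[:100]")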
@@ -119,8 +144,9 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
     if root is not None:
         path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
     else:
+        safe_version = get_hf_dataset_safe_version(repo_id, version)
         path = hf_hub_download(
-            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
+            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
         )

     return load_file(path)
@@ -137,7 +163,10 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
     if root is not None:
         path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
     else:
-        path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
+        safe_version = get_hf_dataset_safe_version(repo_id, version)
+        path = hf_hub_download(
+            repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
+        )

     stats = load_file(path)
     return unflatten_dict(stats)
@@ -154,7 +183,8 @@ def load_info(repo_id, version, root) -> dict:
     if root is not None:
         path = Path(root) / repo_id / "meta_data" / "info.json"
     else:
-        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
+        safe_version = get_hf_dataset_safe_version(repo_id, version)
+        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)

     with open(path) as f:
         info = json.load(f)
@@ -166,7 +196,8 @@ def load_videos(repo_id, version, root) -> Path:
         path = Path(root) / repo_id / "videos"
     else:
         # TODO(rcadene): we download the whole repo here. see if we can avoid this
-        repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version)
+        safe_version = get_hf_dataset_safe_version(repo_id, version)
+        repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
         path = Path(repo_dir) / "videos"

     return path
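One side effect worth noting: because get_hf_dataset_safe_version is wrapped in functools.cache, the loaders above share a single HfApi.list_repo_refs network call per (repo_id, version) pair within a process. A small sketch (the repo id is illustrative):

    v1 = get_hf_dataset_safe_version("lerobot/pusht", "v1.4")  # hits the Hub API
    v2 = get_hf_dataset_safe_version("lerobot/pusht", "v1.4")  # served from the cache
    assert v1 == v2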