Rework LeRobotDataset.__init__

This commit is contained in:
Simon Alibert
2024-10-09 14:33:26 +02:00
parent 2d75b93ba0
commit 096824b5ff
2 changed files with 189 additions and 114 deletions

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import warnings
from functools import cache
from pathlib import Path
@@ -22,10 +21,9 @@ from typing import Dict
import datasets
import torch
from datasets import load_dataset, load_from_disk
from datasets import load_dataset
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
DATASET_CARD_TEMPLATE = """
@@ -96,7 +94,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
@cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
def get_hub_safe_version(repo_id: str, version: str) -> str:
num_version = float(version.strip("v"))
if num_version < 2:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
)
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
@@ -116,56 +121,27 @@ def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
return version
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
def load_hf_dataset(
local_dir: Path,
data_path: str,
total_episodes: int,
episodes: list[int] | None = None,
split: str = "train",
) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
# TODO(rcadene): clean this which enables getting a subset of dataset
if split != "train":
if "%" in split:
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
match_from = re.search(r"train\[(\d+):\]", split)
match_to = re.search(r"train\[:(\d+)\]", split)
if match_from:
from_frame_index = int(match_from.group(1))
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
elif match_to:
to_frame_index = int(match_to.group(1))
hf_dataset = hf_dataset.select(range(to_frame_index))
else:
raise ValueError(
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
)
if episodes is None:
path = str(local_dir / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split=split)
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
files = [str(local_dir / fpath) for fpath in files]
hf_dataset = load_dataset("parquet", data_files=files, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
)
return load_file(path)
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
@@ -173,47 +149,84 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
)
fpath = hf_hub_download(
repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
stats = json.load(f)
stats = load_file(path)
stats = flatten_dict(stats)
stats = {key: torch.tensor(value) for key, value in stats.items()}
return unflatten_dict(stats)
def load_info(repo_id, version, root) -> dict:
"""info contains useful information regarding the dataset that are not stored elsewhere
def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
"""info contains structural information about the dataset. It should be the reference and
act as the 'source of thruth' for what's inside the dataset.
Example:
```python
print("frame per second used to collect the video", info["fps"])
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
with open(path) as f:
info = json.load(f)
return info
fpath = hf_hub_download(
repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
return json.load(f)
def load_videos(repo_id, version, root) -> Path:
if root is not None:
path = Path(root) / repo_id / "videos"
else:
# TODO(rcadene): we download the whole repo here. see if we can avoid this
safe_version = get_hf_dataset_safe_version(repo_id, version)
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
path = Path(repo_dir) / "videos"
def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
"""tasks contains all the tasks of the dataset, indexed by their task_index.
return path
Example:
```json
{
"0": "Pick the Lego block and drop it in the box on the right."
}
```
"""
fpath = hf_hub_download(
repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
return json.load(f)
def download_episodes(
repo_id: str,
version: str,
local_dir: Path,
data_path: str,
video_keys: list,
total_episodes: int,
episodes: list[int] | None = None,
videos_path: str | None = None,
) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
in 'local_dir', they won't be downloaded again.
Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
if episodes is not None:
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
if len(video_keys) > 0:
video_files = [
videos_path.format(video_key=vid_key, episode_index=ep_idx)
for vid_key in video_keys
for ep_idx in episodes
]
files += video_files
snapshot_download(
repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
)
def load_previous_and_future_frames(