Rework LeRobotDataset.__init__
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import re
 import warnings
 from functools import cache
 from pathlib import Path
@@ -22,10 +21,9 @@ from typing import Dict

 import datasets
 import torch
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
 from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
 from PIL import Image as PILImage
-from safetensors.torch import load_file
 from torchvision import transforms

 DATASET_CARD_TEMPLATE = """
@@ -96,7 +94,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):


 @cache
-def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
+def get_hub_safe_version(repo_id: str, version: str) -> str:
+    num_version = float(version.strip("v"))
+    if num_version < 2:
+        raise ValueError(
+            f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
+            format with v2.0 that is not backward compatible. Please use our conversion script
+            first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
+        )
     api = HfApi()
     dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
     branches = [b.name for b in dataset_info.branches]
@@ -116,56 +121,27 @@ def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
     return version


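The renamed helper now refuses pre-v2.0 datasets before touching the Hub. A minimal sketch of calling it, assuming the module path `lerobot.common.datasets.utils` and a hypothetical repo id (resolving a valid version still requires network access to list the repo's branches):

```python
from lerobot.common.datasets.utils import get_hub_safe_version  # assumed module path

repo_id = "lerobot/pusht"  # hypothetical repo id for illustration

# A v1.x request fails fast with the conversion-script hint, before any Hub call:
try:
    get_hub_safe_version(repo_id, "v1.6")
except ValueError as err:
    print(err)  # points at convert_dataset_16_to_20.py

# A v2.0 request proceeds to list_repo_refs and returns a matching branch name:
safe_version = get_hub_safe_version(repo_id, "v2.0")
```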
-def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
+def load_hf_dataset(
+    local_dir: Path,
+    data_path: str,
+    total_episodes: int,
+    episodes: list[int] | None = None,
+    split: str = "train",
+) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
-    if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
-        # TODO(rcadene): clean this which enables getting a subset of dataset
-        if split != "train":
-            if "%" in split:
-                raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
-            match_from = re.search(r"train\[(\d+):\]", split)
-            match_to = re.search(r"train\[:(\d+)\]", split)
-            if match_from:
-                from_frame_index = int(match_from.group(1))
-                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
-            elif match_to:
-                to_frame_index = int(match_to.group(1))
-                hf_dataset = hf_dataset.select(range(to_frame_index))
-            else:
-                raise ValueError(
-                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
-                )
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
+    if episodes is None:
+        path = str(local_dir / "data")
+        hf_dataset = load_dataset("parquet", data_dir=path, split=split)
+    else:
+        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
+        files = [str(local_dir / fpath) for fpath in files]
+        hf_dataset = load_dataset("parquet", data_files=files, split=split)

     hf_dataset.set_transform(hf_transform_to_torch)
     return hf_dataset


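With the new signature, episode selection happens at file level: either every parquet file under `local_dir / "data"` is read, or only the files for the requested episodes. A sketch under assumed values; the `data_path` template below is hypothetical and must match what the dataset's `meta/info.json` declares:

```python
from pathlib import Path

from lerobot.common.datasets.utils import load_hf_dataset  # assumed module path

local_dir = Path("data/lerobot/pusht")  # hypothetical local snapshot location
# Hypothetical template; the real one comes from the dataset's meta/info.json.
data_path = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"

# Load every episode: reads all parquet files under local_dir / "data".
full = load_hf_dataset(local_dir, data_path, total_episodes=200)

# Load a subset: only the parquet files for episodes 0, 1 and 5 are read.
subset = load_hf_dataset(local_dir, data_path, total_episodes=200, episodes=[0, 1, 5])
```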
-def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
-    """episode_data_index contains the range of indices for each episode
-
-    Example:
-    ```python
-    from_id = episode_data_index["from"][episode_id].item()
-    to_id = episode_data_index["to"][episode_id].item()
-    episode_frames = [dataset[i] for i in range(from_id, to_id)]
-    ```
-    """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(
-            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
-        )
-
-    return load_file(path)
-
-
-def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
+def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
     """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std

     Example:
@@ -173,47 +149,84 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
     normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
     ```
     """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(
-            repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
-        )
-
-    stats = load_file(path)
+    fpath = hf_hub_download(
+        repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        stats = json.load(f)
+
+    stats = flatten_dict(stats)
+    stats = {key: torch.tensor(value) for key, value in stats.items()}
     return unflatten_dict(stats)


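The stats now travel as JSON rather than safetensors, so tensors are rebuilt at load time: the nested dict is flattened, each leaf list becomes a tensor, then the nesting is restored. A self-contained sketch of that round-trip, with minimal stand-ins for the module's `flatten_dict` / `unflatten_dict` helpers (the "/"-separated key format is an assumption for illustration):

```python
import torch


def flatten_dict(d: dict, parent: str = "", sep: str = "/") -> dict:
    # Collapse nested dicts into a single level with sep-joined keys.
    items = {}
    for k, v in d.items():
        key = f"{parent}{sep}{k}" if parent else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep))
        else:
            items[key] = v
    return items


def unflatten_dict(d: dict, sep: str = "/") -> dict:
    # Rebuild the nested structure from sep-joined keys.
    out = {}
    for key, v in d.items():
        node = out
        *parents, leaf = key.split(sep)
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = v
    return out


# What load_stats now does with the downloaded meta/stats.json content:
stats_json = {"action": {"mean": [0.1, -0.2], "std": [1.0, 0.5]}}
flat = flatten_dict(stats_json)                       # {"action/mean": [...], "action/std": [...]}
flat = {k: torch.tensor(v) for k, v in flat.items()}  # JSON lists -> torch tensors
stats = unflatten_dict(flat)                          # {"action": {"mean": tensor, "std": tensor}}
```

Doing the conversion at load time keeps the on-disk format human-readable while consumers still receive tensors.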
-def load_info(repo_id, version, root) -> dict:
-    """info contains useful information regarding the dataset that are not stored elsewhere
+def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
+    """info contains structural information about the dataset. It should be the reference and
+    act as the 'source of truth' for what's inside the dataset.

     Example:
     ```python
     print("frame per second used to collect the video", info["fps"])
     ```
     """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "info.json"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
-
-    with open(path) as f:
-        info = json.load(f)
-    return info
+    fpath = hf_hub_download(
+        repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        return json.load(f)


-def load_videos(repo_id, version, root) -> Path:
-    if root is not None:
-        path = Path(root) / repo_id / "videos"
-    else:
-        # TODO(rcadene): we download the whole repo here. see if we can avoid this
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
-        path = Path(repo_dir) / "videos"
-
-    return path
+def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
+    """tasks contains all the tasks of the dataset, indexed by their task_index.
+
+    Example:
+    ```json
+    {
+        "0": "Pick the Lego block and drop it in the box on the right."
+    }
+    ```
+    """
+    fpath = hf_hub_download(
+        repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        return json.load(f)


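Per the docstring's example, `meta/tasks.json` maps stringified task indices to natural-language task descriptions, so a frame's task index can be resolved to its instruction. A sketch with hypothetical values; the `"task_index"` field name is an assumption:

```python
from pathlib import Path

from lerobot.common.datasets.utils import load_tasks  # assumed module path

local_dir = Path("data/lerobot/pusht")  # hypothetical local snapshot location
tasks = load_tasks("lerobot/pusht", "v2.0", local_dir)  # hypothetical repo id

task_index = 0  # e.g. read from a frame's "task_index" entry (assumed field name)
print(tasks[str(task_index)])  # JSON object keys are strings
```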
+def download_episodes(
+    repo_id: str,
+    version: str,
+    local_dir: Path,
+    data_path: str,
+    video_keys: list,
+    total_episodes: int,
+    episodes: list[int] | None = None,
+    videos_path: str | None = None,
+) -> None:
+    """Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
+    will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
+    dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
+    in 'local_dir', they won't be downloaded again.
+
+    Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
+    snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
+    """
+    # TODO(rcadene, aliberts): implement faster transfer
+    # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
+    files = None
+    if episodes is not None:
+        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
+        if len(video_keys) > 0:
+            video_files = [
+                videos_path.format(video_key=vid_key, episode_index=ep_idx)
+                for vid_key in video_keys
+                for ep_idx in episodes
+            ]
+            files += video_files
+
+    snapshot_download(
+        repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
+    )


 def load_previous_and_future_frames(
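A sketch of calling `download_episodes` for a partial download, with hypothetical values throughout; both path templates must match what the dataset's `meta/info.json` declares. With `episodes=None`, `files` stays `None` and `snapshot_download` pulls the whole repo (`allow_patterns=None` matches everything):

```python
from pathlib import Path

from lerobot.common.datasets.utils import download_episodes  # assumed module path

download_episodes(
    repo_id="lerobot/pusht",              # hypothetical repo id
    version="v2.0",
    local_dir=Path("data/lerobot/pusht"),
    data_path="data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet",  # hypothetical template
    video_keys=["observation.image"],     # hypothetical video key
    total_episodes=200,
    episodes=[0, 1, 5],                   # only these episodes' files are fetched
    videos_path="videos/{video_key}_episode_{episode_index:06d}.mp4",  # hypothetical template
)
```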