Add file paths

Simon Alibert
2024-10-20 14:00:19 +02:00
parent ac3798bd62
commit 9316cf46ef
2 changed files with 60 additions and 53 deletions


@@ -23,7 +23,6 @@ from typing import Dict
 import datasets
 import jsonlines
 import torch
-from datasets import load_dataset
 from huggingface_hub import DatasetCard, HfApi
 from PIL import Image as PILImage
 from torchvision import transforms
@@ -87,15 +86,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
-        # TODO(aliberts): remove this part as we'll be using task_index
-        elif isinstance(first_item, str):
-            # TODO (michel-aractingi): add str2embedding via language tokenizer
-            # For now we leave this part up to the user to choose how to address
-            # language conditioned tasks
-            pass
-        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
-            # video frame will be processed downstream
-            pass
         elif first_item is None:
             pass
         else:
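
After this hunk, only three cases remain in `hf_transform_to_torch`: PIL images are converted with `transforms.ToTensor()`, `None` values pass through, and everything else falls to the final branch. With the `str` case gone, raw task strings would now hit that final branch; per the removed TODO, tasks are expected to be referenced by `task_index` instead. A minimal sketch of the simplified behavior, assuming the surviving `else` branch (outside this hunk) wraps values with `torch.tensor`:

```python
import torch
from PIL import Image as PILImage
from torchvision import transforms

def hf_transform_to_torch(items_dict: dict):
    # Simplified per this commit: no str or {"path", "timestamp"} branches.
    for key in items_dict:
        first_item = items_dict[key][0]
        if isinstance(first_item, PILImage.Image):
            to_tensor = transforms.ToTensor()  # HWC uint8 PIL -> CHW float32 in [0, 1]
            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
        elif first_item is None:
            pass
        else:
            # Assumed final branch: numeric values become tensors.
            items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
    return items_dict

batch = {"observation.image": [PILImage.new("RGB", (96, 96))], "action": [[0.1, 0.2]]}
batch = hf_transform_to_torch(batch)
print(batch["observation.image"][0].shape)  # torch.Size([3, 96, 96])
```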
@@ -130,32 +120,12 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
     return version
 
 
-def load_hf_dataset(
-    local_dir: Path,
-    data_path: str,
-    total_episodes: int,
-    episodes: list[int] | None = None,
-    split: str = "train",
-) -> datasets.Dataset:
-    """hf_dataset contains all the observations, states, actions, rewards, etc."""
-
-    if episodes is None:
-        path = str(local_dir / "data")
-        hf_dataset = load_dataset("parquet", data_dir=path, split=split)
-    else:
-        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
-        files = [str(local_dir / fpath) for fpath in files]
-        hf_dataset = load_dataset("parquet", data_files=files, split=split)
-    hf_dataset.set_transform(hf_transform_to_torch)
-    return hf_dataset
-
-
 def load_metadata(local_dir: Path) -> tuple[dict | list]:
     """Loads metadata files from a dataset."""
-    info_path = local_dir / "meta/info.jsonl"
+    info_path = local_dir / "meta/info.json"
     episodes_path = local_dir / "meta/episodes.jsonl"
     stats_path = local_dir / "meta/stats.json"
-    tasks_path = local_dir / "meta/tasks.json"
+    tasks_path = local_dir / "meta/tasks.jsonl"
 
     with open(info_path) as f:
         info = json.load(f)
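
This hunk drops `load_hf_dataset` entirely (matching the removal of the `load_dataset` import above) and shifts the metadata layout: `info` moves from JSON Lines to plain JSON, while `tasks` moves the other way. A hedged sketch of reading the new layout, plus the direct parquet load that the deleted helper performed when `episodes` was `None`; the local directory below is an illustrative assumption, not taken from this diff:

```python
import json
from pathlib import Path

import jsonlines
from datasets import load_dataset

local_dir = Path("~/my_lerobot_dataset").expanduser()  # hypothetical local copy

# New metadata layout per this commit: info.json (plain JSON), tasks.jsonl (JSON Lines).
with open(local_dir / "meta/info.json") as f:
    info = json.load(f)
with jsonlines.open(local_dir / "meta/tasks.jsonl") as reader:
    tasks = list(reader)

# Equivalent of the deleted load_hf_dataset(..., episodes=None): read every
# parquet file under data/ as one "train" split, then apply the torch transform.
hf_dataset = load_dataset("parquet", data_dir=str(local_dir / "data"), split="train")
hf_dataset.set_transform(hf_transform_to_torch)  # transform sketched above
```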
@@ -499,12 +469,17 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
     api.create_branch(repo_id, repo_type=repo_type, branch=branch)
 
 
-def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
+def create_lerobot_dataset_card(
+    tags: list | None = None, text: str | None = None, info: dict | None = None
+) -> DatasetCard:
     card = DatasetCard(DATASET_CARD_TEMPLATE)
     card.data.task_categories = ["robotics"]
     card.data.tags = ["LeRobot"]
     if tags is not None:
         card.data.tags += tags
     if text is not None:
-        card.text += text
+        card.text += f"{text}\n"
+    if info is not None:
+        card.text += "[meta/info.json](meta/info.json)\n"
+        card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
     return card
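
The card builder can now embed the dataset's `meta/info.json` into the card body as a link plus a rendered JSON block; the extra `\n` after `text` keeps user-provided text from fusing with that block. A usage sketch, assuming the helper lives in `lerobot.common.datasets.utils` (the file name is not shown in this view) and with a made-up `info` dict and repo id:

```python
from huggingface_hub import DatasetCard

from lerobot.common.datasets.utils import create_lerobot_dataset_card  # assumed path

info = {"codebase_version": "v2.0", "fps": 10, "total_episodes": 50}  # illustrative
card = create_lerobot_dataset_card(tags=["pusht"], text="Built with LeRobot.", info=info)
assert isinstance(card, DatasetCard)
card.push_to_hub("user/my_dataset", repo_type="dataset")  # hypothetical repo_id
```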