Add file paths
This commit is contained in:
@@ -23,7 +23,6 @@ from typing import Dict
|
||||
import datasets
|
||||
import jsonlines
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import DatasetCard, HfApi
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
@@ -87,15 +86,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
# TODO(aliberts): remove this part as we'll be using task_index
|
||||
elif isinstance(first_item, str):
|
||||
# TODO (michel-aractingi): add str2embedding via language tokenizer
|
||||
# For now we leave this part up to the user to choose how to address
|
||||
# language conditioned tasks
|
||||
pass
|
||||
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
||||
# video frame will be processed downstream
|
||||
pass
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
@@ -130,32 +120,12 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
|
||||
return version
|
||||
|
||||
|
||||
def load_hf_dataset(
    local_dir: Path,
    data_path: str,
    total_episodes: int,
    episodes: list[int] | None = None,
    split: str = "train",
) -> datasets.Dataset:
    """Load the parquet-backed dataset holding observations, states, actions, rewards, etc.

    Args:
        local_dir: Root directory of the dataset on disk. A plain string is
            accepted as well and converted to a ``Path``.
        data_path: ``str.format`` template, relative to ``local_dir``, that is
            filled with ``episode_index`` and ``total_episodes`` to produce the
            parquet file path of one episode.
        total_episodes: Value substituted for ``total_episodes`` in ``data_path``.
        episodes: Episode indices to load. When ``None``, every parquet file
            found under ``local_dir / "data"`` is loaded instead.
        split: Split name forwarded to ``load_dataset``.

    Returns:
        The loaded dataset, with ``hf_transform_to_torch`` installed as its
        on-access transform.
    """
    root = Path(local_dir)

    if episodes is None:
        # No selection requested: let `datasets` discover all files in the data directory.
        hf_dataset = load_dataset("parquet", data_dir=str(root / "data"), split=split)
    else:
        # Resolve one parquet file per requested episode from the path template.
        files = [
            str(root / data_path.format(episode_index=ep_idx, total_episodes=total_episodes))
            for ep_idx in episodes
        ]
        hf_dataset = load_dataset("parquet", data_files=files, split=split)

    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
|
||||
|
||||
|
||||
def load_metadata(local_dir: Path) -> tuple[dict | list]:
|
||||
"""Loads metadata files from a dataset."""
|
||||
info_path = local_dir / "meta/info.jsonl"
|
||||
info_path = local_dir / "meta/info.json"
|
||||
episodes_path = local_dir / "meta/episodes.jsonl"
|
||||
stats_path = local_dir / "meta/stats.json"
|
||||
tasks_path = local_dir / "meta/tasks.json"
|
||||
tasks_path = local_dir / "meta/tasks.jsonl"
|
||||
|
||||
with open(info_path) as f:
|
||||
info = json.load(f)
|
||||
@@ -499,12 +469,17 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(
    tags: list | None = None, text: str | None = None, info: dict | None = None
) -> DatasetCard:
    """Build the dataset card for a LeRobot dataset.

    The card is created from ``DATASET_CARD_TEMPLATE``, always carries the
    ``robotics`` task category and the ``LeRobot`` tag, and is then extended
    with the optional arguments.

    Args:
        tags: Extra tags appended after the default ``"LeRobot"`` tag.
        text: Free-form text appended to the card body, followed by a newline.
        info: Dataset info dictionary. When given, a link to
            ``meta/info.json`` and a JSON dump of the dictionary (4-space
            indent, inside a fenced ``json`` block) are appended to the body.

    Returns:
        The populated ``DatasetCard``.
    """
    card = DatasetCard(DATASET_CARD_TEMPLATE)
    card.data.task_categories = ["robotics"]
    card.data.tags = ["LeRobot"]

    if tags is not None:
        card.data.tags += tags

    if text is not None:
        # Append the text exactly once, newline-terminated so that whatever
        # follows (e.g. the info section) starts on its own line.
        card.text += f"{text}\n"

    if info is not None:
        card.text += "[meta/info.json](meta/info.json)\n"
        card.text += f"```json\n{json.dumps(info, indent=4)}\n```"

    return card
|
||||
|
||||
Reference in New Issue
Block a user