Add extra info to dataset card, various fixes from Remi's review

This commit is contained in:
Simon Alibert
2024-11-18 17:50:13 +01:00
parent 4d15861872
commit a91b7c6163
5 changed files with 250 additions and 82 deletions

View File

@@ -22,6 +22,7 @@ from typing import Any
import datasets
import jsonlines
import numpy as np
import pyarrow.compute as pc
import torch
from datasets.table import embed_table_storage
@@ -91,6 +92,11 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
    """Convert the leaves of a (possibly nested) stats dict to plain lists.

    The input is flattened with ``flatten_dict``, ``.tolist()`` is called on
    every leaf value, and the result is rebuilt into a nested dict with
    ``unflatten_dict``.

    Args:
        stats: Mapping whose leaves expose a ``tolist()`` method
            (e.g. torch tensors or numpy arrays).

    Returns:
        A dict with the same nesting whose leaves are the ``tolist()`` results.
    """
    as_lists = {}
    for flat_key, leaf in flatten_dict(stats).items():
        as_lists[flat_key] = leaf.tolist()
    return unflatten_dict(as_lists)
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
# Embed image bytes into the table before saving to parquet
format = dataset.format
@@ -128,12 +134,6 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
writer.write(data)
def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
    """Serialize ``stats`` and write it as JSON to ``fpath``.

    The conversion of leaf values to lists is delegated to ``serialize_dict``
    so that the flatten / ``tolist`` / unflatten logic lives in one place.

    Args:
        stats: (Possibly nested) mapping whose leaves expose ``tolist()``.
        fpath: Destination path handed to ``write_json``.
    """
    write_json(serialize_dict(stats), fpath)
def load_info(local_dir: Path) -> dict:
    """Load the JSON file located at ``local_dir / INFO_PATH``.

    Args:
        local_dir: Directory that ``INFO_PATH`` is resolved against.

    Returns:
        Whatever ``load_json`` returns for that file.
    """
    info_fpath = local_dir / INFO_PATH
    return load_json(info_fpath)
@@ -153,6 +153,16 @@ def load_episodes(local_dir: Path) -> dict:
return load_jsonlines(local_dir / EPISODES_PATH)
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
    """Load an image file as an RGB numpy array.

    Args:
        fpath: Path of the image to open.
        dtype: Dtype of the returned array. Anything ``np.array`` accepts as a
            dtype works (``"float32"``, ``np.float32``, ``np.dtype("uint8")``...).
        channel_first: If True, the array is transposed from (H, W, C) to
            (C, H, W).

    Returns:
        The image as an array of the requested dtype. When that dtype is a
        floating-point type, the values are divided by 255.
    """
    # Use a context manager so the underlying file is closed as soon as the
    # pixel data has been converted, rather than whenever the object is collected.
    with PILImage.open(fpath) as raw_img:
        img = raw_img.convert("RGB")
    img_array = np.array(img, dtype=dtype)
    if channel_first:  # (H, W, C) -> (C, H, W)
        img_array = np.transpose(img_array, (2, 0, 1))
    # Inspect the actual array dtype instead of doing a substring check on the
    # `dtype` argument: `"float" in dtype` raises TypeError for `np.float32`
    # and silently skips scaling for `np.dtype("float32")`.
    if np.issubdtype(img_array.dtype, np.floating):
        img_array /= 255.0
    return img_array
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
@@ -331,7 +341,7 @@ def check_timestamps_sync(
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
# We mask differences between the timestamp at the end of an episode
# and the one the start of the next episode since these are expected
# and the one at the start of the next episode since these are expected
# to be outside tolerance.
mask = torch.ones(len(diffs), dtype=torch.bool)
ignored_diffs = episode_data_index["to"][:-1] - 1
@@ -433,7 +443,12 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
def create_lerobot_dataset_card(
tags: list | None = None, text: str | None = None, info: dict | None = None
tags: list | None = None,
text: str | None = None,
info: dict | None = None,
license: str | None = None,
citation: str | None = None,
arxiv: str | None = None,
) -> DatasetCard:
card = DatasetCard(DATASET_CARD_TEMPLATE)
card.data.configs = [
@@ -444,11 +459,19 @@ def create_lerobot_dataset_card(
]
card.data.task_categories = ["robotics"]
card.data.tags = ["LeRobot"]
if tags is not None:
if license:
card.data.license = license
if tags:
card.data.tags += tags
if text is not None:
if text:
card.text += f"{text}\n"
if info is not None:
if info:
card.text += "## Info\n"
card.text += "[meta/info.json](meta/info.json)\n"
card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
if citation:
card.text += "## Citation\n"
card.text += f"```\n{citation}\n```\n"
if arxiv:
card.data.arxiv = arxiv
return card