Add extra info to dataset card, various fixes from Remi's review
This commit is contained in:
@@ -22,6 +22,7 @@ from typing import Any
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import pyarrow.compute as pc
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
@@ -91,6 +92,11 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
@@ -128,12 +134,6 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
|
||||
serialized_stats = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
write_json(serialized_stats, fpath)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
return load_json(local_dir / INFO_PATH)
|
||||
|
||||
@@ -153,6 +153,16 @@ def load_episodes(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if "float" in dtype:
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
@@ -331,7 +341,7 @@ def check_timestamps_sync(
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
|
||||
# We mask differences between the timestamp at the end of an episode
|
||||
# and the one the start of the next episode since these are expected
|
||||
# and the one at the start of the next episode since these are expected
|
||||
# to be outside tolerance.
|
||||
mask = torch.ones(len(diffs), dtype=torch.bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1
|
||||
@@ -433,7 +443,12 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None, text: str | None = None, info: dict | None = None
|
||||
tags: list | None = None,
|
||||
text: str | None = None,
|
||||
info: dict | None = None,
|
||||
license: str | None = None,
|
||||
citation: str | None = None,
|
||||
arxiv: str | None = None,
|
||||
) -> DatasetCard:
|
||||
card = DatasetCard(DATASET_CARD_TEMPLATE)
|
||||
card.data.configs = [
|
||||
@@ -444,11 +459,19 @@ def create_lerobot_dataset_card(
|
||||
]
|
||||
card.data.task_categories = ["robotics"]
|
||||
card.data.tags = ["LeRobot"]
|
||||
if tags is not None:
|
||||
if license:
|
||||
card.data.license = license
|
||||
if tags:
|
||||
card.data.tags += tags
|
||||
if text is not None:
|
||||
if text:
|
||||
card.text += f"{text}\n"
|
||||
if info is not None:
|
||||
if info:
|
||||
card.text += "## Info\n"
|
||||
card.text += "[meta/info.json](meta/info.json)\n"
|
||||
card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
|
||||
if citation:
|
||||
card.text += "## Citation\n"
|
||||
card.text += f"```\n{citation}\n```\n"
|
||||
if arxiv:
|
||||
card.data.arxiv = arxiv
|
||||
return card
|
||||
|
||||
Reference in New Issue
Block a user