add write_stats, changes names, add some typing

This commit is contained in:
Simon Alibert
2024-10-23 11:38:07 +02:00
parent fb73cdb9a4
commit a2a8538ac9
5 changed files with 33 additions and 26 deletions

View File

@@ -48,7 +48,7 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
"""
def flatten_dict(d, parent_key="", sep="/"):
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
@@ -67,7 +67,7 @@ def flatten_dict(d, parent_key="", sep="/"):
return dict(items)
def unflatten_dict(d, sep="/"):
def unflatten_dict(d: dict, sep: str = "/") -> dict:
outdict = {}
for key, value in d.items():
parts = key.split(sep)
@@ -92,6 +92,12 @@ def append_jsonl(data: dict, fpath: Path) -> None:
writer.write(data)
def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
serialized_stats = {key: value.tolist() for key, value in flatten_dict(stats).items()}
serialized_stats = unflatten_dict(serialized_stats)
write_json(serialized_stats, fpath)
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to