Add write_stats, change names, add some typing
This commit is contained in:
@@ -48,7 +48,7 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
|
||||
"""
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
|
||||
For example:
|
||||
@@ -67,7 +67,7 @@ def flatten_dict(d, parent_key="", sep="/"):
|
||||
return dict(items)
|
||||
|
||||
|
||||
def unflatten_dict(d, sep="/"):
|
||||
def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
@@ -92,6 +92,12 @@ def append_jsonl(data: dict, fpath: Path) -> None:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
    """Serialize a (possibly nested) statistics dictionary and write it as JSON.

    The nested ``stats`` mapping is first flattened with ``flatten_dict`` so
    every leaf can be visited in a single pass, each leaf is converted with
    its ``.tolist()`` method, and the original nesting is then restored with
    ``unflatten_dict`` before the result is handed to ``write_json``.

    Args:
        stats: Nested dictionary whose leaves expose ``.tolist()``
            (annotated as ``torch.Tensor``).
        fpath: Destination path passed through to ``write_json``.
    """
    flat_serialized = {}
    for flat_key, leaf in flatten_dict(stats).items():
        # Convert each leaf to plain Python lists/scalars via ``.tolist()``.
        flat_serialized[flat_key] = leaf.tolist()

    # Rebuild the nested layout so the file mirrors the shape of ``stats``.
    write_json(unflatten_dict(flat_serialized), fpath)
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
|
||||
Reference in New Issue
Block a user