pre-commit run --all-files

This commit is contained in:
Remi Cadene
2025-04-21 09:34:19 +02:00
parent 5a6ea09248
commit 4acf99f622
7 changed files with 85 additions and 33 deletions

View File

@@ -24,7 +24,7 @@ from collections.abc import Iterator
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any, Tuple
from typing import Any
import datasets
import numpy as np
@@ -47,7 +47,7 @@ from lerobot.common.datasets.backward_compatibility import (
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PolicyFeature
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_FILE_SIZE_IN_MB = 500.0 # Max size per file
@@ -249,34 +249,41 @@ def load_json(fpath: Path) -> Any:
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def write_info(info: dict, local_dir: Path):
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
def write_stats(stats: dict, local_dir: Path):
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_FILE_SIZE_IN_MB:
raise NotImplementedError("Contact a maintainer.")
@@ -292,7 +299,6 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
tasks.to_parquet(path)
def load_tasks(local_dir: Path):
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
return tasks