[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523
79 changed files with 2782 additions and 788 deletions

View File

@@ -23,7 +23,11 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
CODEBASE_VERSION,
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
@@ -54,7 +58,9 @@ def get_task_index(task_dicts: dict, task: str) -> int:
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
def _create_img_tensor(
height=100, width=100, channels=3, dtype=torch.float32
) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@@ -62,10 +68,14 @@ def img_tensor_factory():
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
def _create_img_array(
height=100, width=100, channels=3, dtype=np.uint8
) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
img_array = np.random.randint(
0, 256, size=(height, width, channels), dtype=dtype
)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
@@ -94,10 +104,13 @@ def features_factory():
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO}
for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
camera_ft = {
key: {"dtype": "image", **ft} for key, ft in camera_features.items()
}
return {
**motor_features,
**camera_ft,
@@ -215,7 +228,9 @@ def episodes_factory(tasks_factory):
if total_episodes <= 0 or total_frames <= 0:
raise ValueError("num_episodes and total_length must be positive integers.")
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
raise ValueError(
"total_length must be greater than or equal to num_episodes."
)
if not tasks:
min_tasks = 2 if multi_task else 1
@@ -223,10 +238,14 @@ def episodes_factory(tasks_factory):
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
raise ValueError(
"The number of tasks should be less than the number of episodes."
)
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
lengths = np.random.multinomial(
total_frames, [1 / total_episodes] * total_episodes
).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
num_tasks_available = len(tasks_list)
@@ -234,9 +253,13 @@ def episodes_factory(tasks_factory):
episodes = {}
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
num_tasks_in_episode = (
random.randint(1, min(3, num_tasks_available)) if multi_task else 1
)
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
episode_tasks = random.sample(
tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))
)
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
@@ -253,7 +276,9 @@ def episodes_factory(tasks_factory):
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def hf_dataset_factory(
features_factory, tasks_factory, episodes_factory, img_array_factory
):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
@@ -275,10 +300,15 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
(
episode_index_col,
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
)
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
task_index = np.concatenate(
(task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))
)
index_col = np.arange(len(episode_index_col))
@@ -290,7 +320,9 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
robot_cols[key] = np.random.random(
(len(index_col), ft["shape"][0])
).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
@@ -340,7 +372,9 @@ def lerobot_dataset_metadata_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
mock_snapshot_download = mock_snapshot_download_factory(
@@ -392,7 +426,9 @@ def lerobot_dataset_factory(
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
total_episodes=total_episodes,
total_frames=total_frames,
total_tasks=total_tasks,
)
if not stats:
stats = stats_factory(features=info["features"])
@@ -408,7 +444,9 @@ def lerobot_dataset_factory(
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
hf_dataset = hf_dataset_factory(
tasks=tasks, episodes=episode_dicts, fps=info["fps"]
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,