[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
parent d8a1758122
commit 584cad808e
108 changed files with 3894 additions and 1189 deletions

View File

@@ -7,17 +7,39 @@ DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"laptop": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
"phone": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {

View File

@@ -8,7 +8,11 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
CODEBASE_VERSION,
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
@@ -35,7 +39,9 @@ def get_task_index(task_dicts: dict, task: str) -> int:
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
def _create_img_tensor(
height=100, width=100, channels=3, dtype=torch.float32
) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@@ -43,10 +49,14 @@ def img_tensor_factory():
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
def _create_img_array(
height=100, width=100, channels=3, dtype=np.uint8
) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
img_array = np.random.randint(
0, 256, size=(height, width, channels), dtype=dtype
)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
@@ -75,10 +85,13 @@ def features_factory():
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO}
for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
camera_ft = {
key: {"dtype": "image", **ft} for key, ft in camera_features.items()
}
return {
**motor_features,
**camera_ft,
@@ -177,7 +190,9 @@ def episodes_factory(tasks_factory):
if total_episodes <= 0 or total_frames <= 0:
raise ValueError("num_episodes and total_length must be positive integers.")
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
raise ValueError(
"total_length must be greater than or equal to num_episodes."
)
if not tasks:
min_tasks = 2 if multi_task else 1
@@ -185,10 +200,14 @@ def episodes_factory(tasks_factory):
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
raise ValueError(
"The number of tasks should be less than the number of episodes."
)
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
lengths = np.random.multinomial(
total_frames, [1 / total_episodes] * total_episodes
).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks]
num_tasks_available = len(tasks_list)
@@ -196,9 +215,13 @@ def episodes_factory(tasks_factory):
episodes_list = []
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
num_tasks_in_episode = (
random.randint(1, min(3, num_tasks_available)) if multi_task else 1
)
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
episode_tasks = random.sample(
tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))
)
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
@@ -217,7 +240,9 @@ def episodes_factory(tasks_factory):
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def hf_dataset_factory(
features_factory, tasks_factory, episodes_factory, img_array_factory
):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
@@ -236,13 +261,22 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
timestamp_col = np.concatenate(
(timestamp_col, np.arange(ep_dict["length"]) / fps)
)
frame_index_col = np.concatenate(
(frame_index_col, np.arange(ep_dict["length"], dtype=int))
)
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
(
episode_index_col,
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
)
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
task_index = np.concatenate(
(task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))
)
index_col = np.arange(len(episode_index_col))
@@ -254,7 +288,9 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
robot_cols[key] = np.random.random(
(len(index_col), ft["shape"][0])
).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
@@ -299,7 +335,9 @@ def lerobot_dataset_metadata_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
mock_snapshot_download = mock_snapshot_download_factory(
@@ -316,10 +354,14 @@ def lerobot_dataset_metadata_factory(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
mock_get_hub_safe_version_patch.side_effect = (
lambda repo_id, version: version
)
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
return LeRobotDatasetMetadata(
repo_id=repo_id, root=root, local_files_only=local_files_only
)
return _create_lerobot_dataset_metadata
@@ -350,7 +392,9 @@ def lerobot_dataset_factory(
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
total_episodes=total_episodes,
total_frames=total_frames,
total_tasks=total_tasks,
)
if not stats:
stats = stats_factory(features=info["features"])
@@ -364,7 +408,9 @@ def lerobot_dataset_factory(
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
hf_dataset = hf_dataset_factory(
tasks=tasks, episodes=episode_dicts, fps=info["fps"]
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
@@ -383,7 +429,9 @@ def lerobot_dataset_factory(
local_files_only=kwargs.get("local_files_only", False),
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata"
) as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,

View File

@@ -7,7 +7,12 @@ import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from lerobot.common.datasets.utils import (
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
@pytest.fixture(scope="session")
@@ -69,7 +74,10 @@ def episode_path(episodes_factory):
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
dir: Path,
ep_idx: int = 0,
hf_dataset: datasets.Dataset | None = None,
info: dict | None = None,
) -> Path:
if not info:
info = info_factory()

31
tests/fixtures/hub.py vendored
View File

@@ -4,7 +4,12 @@ import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from lerobot.common.datasets.utils import (
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR
@@ -41,15 +46,21 @@ def mock_snapshot_download_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
hf_dataset = hf_dataset_factory(
tasks=tasks, episodes=episodes, fps=info["fps"]
)
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
episode_index = int(
path.stem[len("episode_") :]
) # 'episode_000000' -> 0
return episode_index
else:
return None
@@ -74,12 +85,16 @@ def mock_snapshot_download_factory(
for episode_dict in episodes:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_path = info["data_path"].format(
episode_chunk=ep_chunk, episode_index=ep_idx
)
data_files.append(data_path)
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
all_files,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
# Create allowed files
@@ -87,7 +102,9 @@ def mock_snapshot_download_factory(
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
_ = single_episode_parquet_path(
local_dir, episode_index, hf_dataset, info
)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH: