Format file

This commit is contained in:
AdilZouitine
2025-05-07 10:26:18 +02:00
parent adbf8bb85e
commit b36ec31fea
13 changed files with 43 additions and 169 deletions

View File

@@ -20,39 +20,17 @@ DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
"phone": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {

View File

@@ -23,11 +23,7 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import (
CODEBASE_VERSION,
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
@@ -201,10 +197,7 @@ def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks = {}
for task_index in range(total_tasks):
task_dict = {
"task_index": task_index,
"task": f"Perform action {task_index}.",
}
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
tasks[task_index] = task_dict
return tasks
@@ -282,10 +275,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(
episode_index_col,
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
)
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
@@ -350,9 +340,7 @@ def lerobot_dataset_metadata_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
mock_snapshot_download = mock_snapshot_download_factory(
@@ -404,9 +392,7 @@ def lerobot_dataset_factory(
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes,
total_frames=total_frames,
total_tasks=total_tasks,
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
stats = stats_factory(features=info["features"])

View File

@@ -102,10 +102,7 @@ def episode_path(episodes_factory):
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path,
ep_idx: int = 0,
hf_dataset: datasets.Dataset | None = None,
info: dict | None = None,
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()

16
tests/fixtures/hub.py vendored
View File

@@ -67,9 +67,7 @@ def mock_snapshot_download_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
@@ -95,13 +93,7 @@ def mock_snapshot_download_factory(
# List all possible files
all_files = []
meta_files = [
INFO_PATH,
STATS_PATH,
EPISODES_STATS_PATH,
TASKS_PATH,
EPISODES_PATH,
]
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files)
data_files = []
@@ -113,9 +105,7 @@ def mock_snapshot_download_factory(
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
# Create allowed files