Format file

This commit is contained in:
AdilZouitine
2025-05-07 10:26:18 +02:00
parent adbf8bb85e
commit b36ec31fea
13 changed files with 43 additions and 169 deletions

View File

@@ -23,11 +23,7 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import (
CODEBASE_VERSION,
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
@@ -201,10 +197,7 @@ def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks = {}
for task_index in range(total_tasks):
task_dict = {
"task_index": task_index,
"task": f"Perform action {task_index}.",
}
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
tasks[task_index] = task_dict
return tasks
@@ -282,10 +275,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(
episode_index_col,
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
)
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
@@ -350,9 +340,7 @@ def lerobot_dataset_metadata_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
mock_snapshot_download = mock_snapshot_download_factory(
@@ -404,9 +392,7 @@ def lerobot_dataset_factory(
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes,
total_frames=total_frames,
total_tasks=total_tasks,
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
stats = stats_factory(features=info["features"])