forked from tangger/lerobot
Format file
This commit is contained in:
24
tests/fixtures/dataset_factories.py
vendored
24
tests/fixtures/dataset_factories.py
vendored
@@ -23,11 +23,7 @@ import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
CODEBASE_VERSION,
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_FEATURES,
|
||||
@@ -201,10 +197,7 @@ def tasks_factory():
|
||||
def _create_tasks(total_tasks: int = 3) -> int:
|
||||
tasks = {}
|
||||
for task_index in range(total_tasks):
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": f"Perform action {task_index}.",
|
||||
}
|
||||
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
|
||||
tasks[task_index] = task_dict
|
||||
return tasks
|
||||
|
||||
@@ -282,10 +275,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
episode_index_col = np.concatenate(
|
||||
(
|
||||
episode_index_col,
|
||||
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
|
||||
)
|
||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||
)
|
||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||
@@ -350,9 +340,7 @@ def lerobot_dataset_metadata_factory(
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
@@ -404,9 +392,7 @@ def lerobot_dataset_factory(
|
||||
) -> LeRobotDataset:
|
||||
if not info:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
total_tasks=total_tasks,
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
|
||||
Reference in New Issue
Block a user