[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523

30  tests/fixtures/constants.py  (vendored)
@@ -20,17 +20,39 @@ DUMMY_MOTOR_FEATURES = {
     "action": {
         "dtype": "float32",
         "shape": (6,),
-        "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
+        "names": [
+            "shoulder_pan",
+            "shoulder_lift",
+            "elbow_flex",
+            "wrist_flex",
+            "wrist_roll",
+            "gripper",
+        ],
     },
     "state": {
         "dtype": "float32",
         "shape": (6,),
-        "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
+        "names": [
+            "shoulder_pan",
+            "shoulder_lift",
+            "elbow_flex",
+            "wrist_flex",
+            "wrist_roll",
+            "gripper",
+        ],
     },
 }
 DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
-    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {
+        "shape": (480, 640, 3),
+        "names": ["height", "width", "channels"],
+        "info": None,
+    },
+    "phone": {
+        "shape": (480, 640, 3),
+        "names": ["height", "width", "channels"],
+        "info": None,
+    },
 }
 DEFAULT_FPS = 30
 DUMMY_VIDEO_INFO = {
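
For reference, these constants feed the fixture factories in the files below. A minimal usage sketch, assuming the tests.fixtures.constants import path, that merges motor and camera features the same way the features_factory else branch (next file) does:

# Sketch only: combine the fixture constants into one features dict,
# mirroring the dict comprehension in features_factory for image datasets.
from tests.fixtures.constants import DUMMY_CAMERA_FEATURES, DUMMY_MOTOR_FEATURES

camera_ft = {key: {"dtype": "image", **ft} for key, ft in DUMMY_CAMERA_FEATURES.items()}
features = {**DUMMY_MOTOR_FEATURES, **camera_ft}

assert features["action"]["shape"] == (6,)
assert features["laptop"]["names"] == ["height", "width", "channels"]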

74  tests/fixtures/dataset_factories.py  (vendored)
@@ -23,7 +23,11 @@ import PIL.Image
 import pytest
 import torch
 
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.common.datasets.lerobot_dataset import (
+    CODEBASE_VERSION,
+    LeRobotDataset,
+    LeRobotDatasetMetadata,
+)
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
     DEFAULT_FEATURES,
@@ -54,7 +58,9 @@ def get_task_index(task_dicts: dict, task: str) -> int:
 
 @pytest.fixture(scope="session")
 def img_tensor_factory():
-    def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
+    def _create_img_tensor(
+        height=100, width=100, channels=3, dtype=torch.float32
+    ) -> torch.Tensor:
         return torch.rand((channels, height, width), dtype=dtype)
 
     return _create_img_tensor
@@ -62,10 +68,14 @@ def img_tensor_factory():
 
 @pytest.fixture(scope="session")
 def img_array_factory():
-    def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
+    def _create_img_array(
+        height=100, width=100, channels=3, dtype=np.uint8
+    ) -> np.ndarray:
         if np.issubdtype(dtype, np.unsignedinteger):
             # Int array in [0, 255] range
-            img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
+            img_array = np.random.randint(
+                0, 256, size=(height, width, channels), dtype=dtype
+            )
         elif np.issubdtype(dtype, np.floating):
             # Float array in [0, 1] range
             img_array = np.random.rand(height, width, channels).astype(dtype)
@@ -94,10 +104,13 @@ def features_factory():
     ) -> dict:
         if use_videos:
             camera_ft = {
-                key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
+                key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO}
+                for key, ft in camera_features.items()
             }
         else:
-            camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
+            camera_ft = {
+                key: {"dtype": "image", **ft} for key, ft in camera_features.items()
+            }
         return {
             **motor_features,
             **camera_ft,
@@ -215,7 +228,9 @@ def episodes_factory(tasks_factory):
         if total_episodes <= 0 or total_frames <= 0:
             raise ValueError("num_episodes and total_length must be positive integers.")
         if total_frames < total_episodes:
-            raise ValueError("total_length must be greater than or equal to num_episodes.")
+            raise ValueError(
+                "total_length must be greater than or equal to num_episodes."
+            )
 
         if not tasks:
             min_tasks = 2 if multi_task else 1
@@ -223,10 +238,14 @@ def episodes_factory(tasks_factory):
             tasks = tasks_factory(total_tasks)
 
         if total_episodes < len(tasks) and not multi_task:
-            raise ValueError("The number of tasks should be less than the number of episodes.")
+            raise ValueError(
+                "The number of tasks should be less than the number of episodes."
+            )
 
         # Generate random lengths that sum up to total_length
-        lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
+        lengths = np.random.multinomial(
+            total_frames, [1 / total_episodes] * total_episodes
+        ).tolist()
 
         tasks_list = [task_dict["task"] for task_dict in tasks.values()]
         num_tasks_available = len(tasks_list)
@@ -234,9 +253,13 @@ def episodes_factory(tasks_factory):
         episodes = {}
         remaining_tasks = tasks_list.copy()
         for ep_idx in range(total_episodes):
-            num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
+            num_tasks_in_episode = (
+                random.randint(1, min(3, num_tasks_available)) if multi_task else 1
+            )
             tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
-            episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
+            episode_tasks = random.sample(
+                tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))
+            )
             if remaining_tasks:
                 for task in episode_tasks:
                     remaining_tasks.remove(task)
@@ -253,7 +276,9 @@ def episodes_factory(tasks_factory):
 
 
 @pytest.fixture(scope="session")
-def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
+def hf_dataset_factory(
+    features_factory, tasks_factory, episodes_factory, img_array_factory
+):
     def _create_hf_dataset(
         features: dict | None = None,
         tasks: list[dict] | None = None,
@@ -275,10 +300,15 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
             timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
             frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
             episode_index_col = np.concatenate(
-                (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
+                (
+                    episode_index_col,
+                    np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
+                )
             )
             ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
-            task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
+            task_index = np.concatenate(
+                (task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))
+            )
 
         index_col = np.arange(len(episode_index_col))
 
@@ -290,7 +320,9 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
                     for _ in range(len(index_col))
                 ]
             elif ft["shape"][0] > 1 and ft["dtype"] != "video":
-                robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
+                robot_cols[key] = np.random.random(
+                    (len(index_col), ft["shape"][0])
+                ).astype(ft["dtype"])
 
         hf_features = get_hf_features_from_features(features)
         dataset = datasets.Dataset.from_dict(
@@ -340,7 +372,9 @@ def lerobot_dataset_metadata_factory(
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if not episodes:
             episodes = episodes_factory(
-                total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
+                total_episodes=info["total_episodes"],
+                total_frames=info["total_frames"],
+                tasks=tasks,
             )
 
         mock_snapshot_download = mock_snapshot_download_factory(
@@ -392,7 +426,9 @@ def lerobot_dataset_factory(
     ) -> LeRobotDataset:
         if not info:
             info = info_factory(
-                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
+                total_episodes=total_episodes,
+                total_frames=total_frames,
+                total_tasks=total_tasks,
            )
         if not stats:
             stats = stats_factory(features=info["features"])
@@ -408,7 +444,9 @@ def lerobot_dataset_factory(
             multi_task=multi_task,
         )
         if not hf_dataset:
-            hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
+            hf_dataset = hf_dataset_factory(
+                tasks=tasks, episodes=episode_dicts, fps=info["fps"]
+            )
 
         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
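
Most of the rewrapped calls above are pure formatting, but the multinomial draw in episodes_factory is worth a worked example. A self-contained sketch, with illustrative values not taken from the diff, of the property the fixture relies on:

import numpy as np

total_episodes, total_frames = 4, 100  # illustrative values
# One multinomial draw over equal probabilities splits total_frames into
# total_episodes non-negative counts that sum exactly to total_frames.
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()

assert len(lengths) == total_episodes
assert sum(lengths) == total_frames

Each draw distributes total_frames across episodes, so per-episode lengths always reconcile with the requested dataset size.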

5  tests/fixtures/files.py  (vendored)
@@ -102,7 +102,10 @@ def episode_path(episodes_factory):
 @pytest.fixture(scope="session")
 def single_episode_parquet_path(hf_dataset_factory, info_factory):
     def _create_single_episode_parquet(
-        dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
+        dir: Path,
+        ep_idx: int = 0,
+        hf_dataset: datasets.Dataset | None = None,
+        info: dict | None = None,
     ) -> Path:
         if not info:
             info = info_factory()

24  tests/fixtures/hub.py  (vendored)
@@ -67,15 +67,21 @@ def mock_snapshot_download_factory(
         tasks = tasks_factory(total_tasks=info["total_tasks"])
     if not episodes:
         episodes = episodes_factory(
-            total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
+            total_episodes=info["total_episodes"],
+            total_frames=info["total_frames"],
+            tasks=tasks,
        )
     if not hf_dataset:
-        hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
+        hf_dataset = hf_dataset_factory(
+            tasks=tasks, episodes=episodes, fps=info["fps"]
+        )
 
     def _extract_episode_index_from_path(fpath: str) -> int:
         path = Path(fpath)
         if path.suffix == ".parquet" and path.stem.startswith("episode_"):
-            episode_index = int(path.stem[len("episode_") :])  # 'episode_000000' -> 0
+            episode_index = int(
+                path.stem[len("episode_") :]
+            )  # 'episode_000000' -> 0
             return episode_index
         else:
             return None
@@ -100,12 +106,16 @@ def mock_snapshot_download_factory(
         for episode_dict in episodes.values():
             ep_idx = episode_dict["episode_index"]
             ep_chunk = ep_idx // info["chunks_size"]
-            data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
+            data_path = info["data_path"].format(
+                episode_chunk=ep_chunk, episode_index=ep_idx
+            )
             data_files.append(data_path)
         all_files.extend(data_files)
 
         allowed_files = filter_repo_objects(
-            all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
+            all_files,
+            allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
         )
 
         # Create allowed files
@@ -113,7 +123,9 @@ def mock_snapshot_download_factory(
             if rel_path.startswith("data/"):
                 episode_index = _extract_episode_index_from_path(rel_path)
                 if episode_index is not None:
-                    _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
+                    _ = single_episode_parquet_path(
+                        local_dir, episode_index, hf_dataset, info
+                    )
             if rel_path == INFO_PATH:
                 _ = info_path(local_dir, info)
             elif rel_path == STATS_PATH:
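
The episode-index parsing rewrapped above is the one piece of hub.py logic dense enough to merit a standalone sketch. The file path below is hypothetical, but it follows the episode_XXXXXX naming visible in the inline comment:

from pathlib import Path

# Hypothetical parquet path in the mocked repo layout.
path = Path("data/chunk-000/episode_000042.parquet")
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
    episode_index = int(path.stem[len("episode_") :])  # 'episode_000042' -> 42

assert episode_index == 42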