fix(tests): remove lint warnings/errors
This commit is contained in:
@@ -108,7 +108,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
|||||||
|
|
||||||
|
|
||||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||||
for i in enumerate(stats_list):
|
for i in range(len(stats_list)):
|
||||||
for fkey in stats_list[i]:
|
for fkey in stats_list[i]:
|
||||||
for k, v in stats_list[i][fkey].items():
|
for k, v in stats_list[i][fkey].items():
|
||||||
if not isinstance(v, np.ndarray):
|
if not isinstance(v, np.ndarray):
|
||||||
|
|||||||
68
tests/fixtures/dataset_factories.py
vendored
68
tests/fixtures/dataset_factories.py
vendored
@@ -52,16 +52,16 @@ def get_task_index(task_dicts: dict, task: str) -> int:
|
|||||||
return task_to_task_index[task]
|
return task_to_task_index[task]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="img_tensor_factory", scope="session")
|
||||||
def img_tensor_factory():
|
def fixture_img_tensor_factory():
|
||||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||||
return torch.rand((channels, height, width), dtype=dtype)
|
return torch.rand((channels, height, width), dtype=dtype)
|
||||||
|
|
||||||
return _create_img_tensor
|
return _create_img_tensor
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="img_array_factory", scope="session")
|
||||||
def img_array_factory():
|
def fixture_img_array_factory():
|
||||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||||
if np.issubdtype(dtype, np.unsignedinteger):
|
if np.issubdtype(dtype, np.unsignedinteger):
|
||||||
# Int array in [0, 255] range
|
# Int array in [0, 255] range
|
||||||
@@ -76,8 +76,8 @@ def img_array_factory():
|
|||||||
return _create_img_array
|
return _create_img_array
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="img_factory", scope="session")
|
||||||
def img_factory(img_array_factory):
|
def fixture_img_factory(img_array_factory):
|
||||||
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
||||||
img_array = img_array_factory(height=height, width=width)
|
img_array = img_array_factory(height=height, width=width)
|
||||||
return PIL.Image.fromarray(img_array)
|
return PIL.Image.fromarray(img_array)
|
||||||
@@ -85,13 +85,17 @@ def img_factory(img_array_factory):
|
|||||||
return _create_img
|
return _create_img
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="features_factory", scope="session")
|
||||||
def features_factory():
|
def fixture_features_factory():
|
||||||
def _create_features(
|
def _create_features(
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict | None = None,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict | None = None,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if motor_features is None:
|
||||||
|
motor_features = DUMMY_MOTOR_FEATURES
|
||||||
|
if camera_features is None:
|
||||||
|
camera_features = DUMMY_CAMERA_FEATURES
|
||||||
if use_videos:
|
if use_videos:
|
||||||
camera_ft = {
|
camera_ft = {
|
||||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||||
@@ -107,8 +111,8 @@ def features_factory():
|
|||||||
return _create_features
|
return _create_features
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="info_factory", scope="session")
|
||||||
def info_factory(features_factory):
|
def fixture_info_factory(features_factory):
|
||||||
def _create_info(
|
def _create_info(
|
||||||
codebase_version: str = CODEBASE_VERSION,
|
codebase_version: str = CODEBASE_VERSION,
|
||||||
fps: int = DEFAULT_FPS,
|
fps: int = DEFAULT_FPS,
|
||||||
@@ -121,10 +125,14 @@ def info_factory(features_factory):
|
|||||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||||
data_path: str = DEFAULT_PARQUET_PATH,
|
data_path: str = DEFAULT_PARQUET_PATH,
|
||||||
video_path: str = DEFAULT_VIDEO_PATH,
|
video_path: str = DEFAULT_VIDEO_PATH,
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict | None = None,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict | None = None,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if motor_features is None:
|
||||||
|
motor_features = DUMMY_MOTOR_FEATURES
|
||||||
|
if camera_features is None:
|
||||||
|
camera_features = DUMMY_CAMERA_FEATURES
|
||||||
features = features_factory(motor_features, camera_features, use_videos)
|
features = features_factory(motor_features, camera_features, use_videos)
|
||||||
return {
|
return {
|
||||||
"codebase_version": codebase_version,
|
"codebase_version": codebase_version,
|
||||||
@@ -145,8 +153,8 @@ def info_factory(features_factory):
|
|||||||
return _create_info
|
return _create_info
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="stats_factory", scope="session")
|
||||||
def stats_factory():
|
def fixture_stats_factory():
|
||||||
def _create_stats(
|
def _create_stats(
|
||||||
features: dict[str] | None = None,
|
features: dict[str] | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -175,8 +183,8 @@ def stats_factory():
|
|||||||
return _create_stats
|
return _create_stats
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="episodes_stats_factory", scope="session")
|
||||||
def episodes_stats_factory(stats_factory):
|
def fixture_episodes_stats_factory(stats_factory):
|
||||||
def _create_episodes_stats(
|
def _create_episodes_stats(
|
||||||
features: dict[str],
|
features: dict[str],
|
||||||
total_episodes: int = 3,
|
total_episodes: int = 3,
|
||||||
@@ -192,8 +200,8 @@ def episodes_stats_factory(stats_factory):
|
|||||||
return _create_episodes_stats
|
return _create_episodes_stats
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="tasks_factory", scope="session")
|
||||||
def tasks_factory():
|
def fixture_tasks_factory():
|
||||||
def _create_tasks(total_tasks: int = 3) -> int:
|
def _create_tasks(total_tasks: int = 3) -> int:
|
||||||
tasks = {}
|
tasks = {}
|
||||||
for task_index in range(total_tasks):
|
for task_index in range(total_tasks):
|
||||||
@@ -204,8 +212,8 @@ def tasks_factory():
|
|||||||
return _create_tasks
|
return _create_tasks
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="episodes_factory", scope="session")
|
||||||
def episodes_factory(tasks_factory):
|
def fixture_episodes_factory(tasks_factory):
|
||||||
def _create_episodes(
|
def _create_episodes(
|
||||||
total_episodes: int = 3,
|
total_episodes: int = 3,
|
||||||
total_frames: int = 400,
|
total_frames: int = 400,
|
||||||
@@ -252,8 +260,8 @@ def episodes_factory(tasks_factory):
|
|||||||
return _create_episodes
|
return _create_episodes
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="hf_dataset_factory", scope="session")
|
||||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
def fixture_hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||||
def _create_hf_dataset(
|
def _create_hf_dataset(
|
||||||
features: dict | None = None,
|
features: dict | None = None,
|
||||||
tasks: list[dict] | None = None,
|
tasks: list[dict] | None = None,
|
||||||
@@ -310,8 +318,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
|||||||
return _create_hf_dataset
|
return _create_hf_dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="lerobot_dataset_metadata_factory", scope="session")
|
||||||
def lerobot_dataset_metadata_factory(
|
def fixture_lerobot_dataset_metadata_factory(
|
||||||
info_factory,
|
info_factory,
|
||||||
stats_factory,
|
stats_factory,
|
||||||
episodes_stats_factory,
|
episodes_stats_factory,
|
||||||
@@ -364,8 +372,8 @@ def lerobot_dataset_metadata_factory(
|
|||||||
return _create_lerobot_dataset_metadata
|
return _create_lerobot_dataset_metadata
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="lerobot_dataset_factory", scope="session")
|
||||||
def lerobot_dataset_factory(
|
def fixture_lerobot_dataset_factory(
|
||||||
info_factory,
|
info_factory,
|
||||||
stats_factory,
|
stats_factory,
|
||||||
episodes_stats_factory,
|
episodes_stats_factory,
|
||||||
@@ -443,6 +451,6 @@ def lerobot_dataset_factory(
|
|||||||
return _create_lerobot_dataset
|
return _create_lerobot_dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(name="empty_lerobot_dataset_factory", scope="session")
|
||||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
def fixture_empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
||||||
|
|||||||
34
tests/fixtures/files.py
vendored
34
tests/fixtures/files.py
vendored
@@ -31,12 +31,12 @@ from lerobot.common.datasets.utils import (
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def info_path(info_factory):
|
def info_path(info_factory):
|
||||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
def _create_info_json_file(input_dir: Path, info: dict | None = None) -> Path:
|
||||||
if not info:
|
if not info:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
fpath = dir / INFO_PATH
|
fpath = input_dir / INFO_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(fpath, "w") as f:
|
with open(fpath, "w", encoding="utf-8") as f:
|
||||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
json.dump(info, f, indent=4, ensure_ascii=False)
|
||||||
return fpath
|
return fpath
|
||||||
|
|
||||||
@@ -45,12 +45,12 @@ def info_path(info_factory):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def stats_path(stats_factory):
|
def stats_path(stats_factory):
|
||||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
def _create_stats_json_file(input_dir: Path, stats: dict | None = None) -> Path:
|
||||||
if not stats:
|
if not stats:
|
||||||
stats = stats_factory()
|
stats = stats_factory()
|
||||||
fpath = dir / STATS_PATH
|
fpath = input_dir / STATS_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(fpath, "w") as f:
|
with open(fpath, "w", encoding="utf-8") as f:
|
||||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
json.dump(stats, f, indent=4, ensure_ascii=False)
|
||||||
return fpath
|
return fpath
|
||||||
|
|
||||||
@@ -59,10 +59,10 @@ def stats_path(stats_factory):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def episodes_stats_path(episodes_stats_factory):
|
def episodes_stats_path(episodes_stats_factory):
|
||||||
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
def _create_episodes_stats_jsonl_file(input_dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
||||||
if not episodes_stats:
|
if not episodes_stats:
|
||||||
episodes_stats = episodes_stats_factory()
|
episodes_stats = episodes_stats_factory()
|
||||||
fpath = dir / EPISODES_STATS_PATH
|
fpath = input_dir / EPISODES_STATS_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
with jsonlines.open(fpath, "w") as writer:
|
||||||
writer.write_all(episodes_stats.values())
|
writer.write_all(episodes_stats.values())
|
||||||
@@ -73,10 +73,10 @@ def episodes_stats_path(episodes_stats_factory):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def tasks_path(tasks_factory):
|
def tasks_path(tasks_factory):
|
||||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
def _create_tasks_jsonl_file(input_dir: Path, tasks: list | None = None) -> Path:
|
||||||
if not tasks:
|
if not tasks:
|
||||||
tasks = tasks_factory()
|
tasks = tasks_factory()
|
||||||
fpath = dir / TASKS_PATH
|
fpath = input_dir / TASKS_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
with jsonlines.open(fpath, "w") as writer:
|
||||||
writer.write_all(tasks.values())
|
writer.write_all(tasks.values())
|
||||||
@@ -87,10 +87,10 @@ def tasks_path(tasks_factory):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def episode_path(episodes_factory):
|
def episode_path(episodes_factory):
|
||||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
def _create_episodes_jsonl_file(input_dir: Path, episodes: list | None = None) -> Path:
|
||||||
if not episodes:
|
if not episodes:
|
||||||
episodes = episodes_factory()
|
episodes = episodes_factory()
|
||||||
fpath = dir / EPISODES_PATH
|
fpath = input_dir / EPISODES_PATH
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
with jsonlines.open(fpath, "w") as writer:
|
||||||
writer.write_all(episodes.values())
|
writer.write_all(episodes.values())
|
||||||
@@ -102,7 +102,7 @@ def episode_path(episodes_factory):
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||||
def _create_single_episode_parquet(
|
def _create_single_episode_parquet(
|
||||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
input_dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||||
) -> Path:
|
) -> Path:
|
||||||
if not info:
|
if not info:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
@@ -112,7 +112,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
|||||||
data_path = info["data_path"]
|
data_path = info["data_path"]
|
||||||
chunks_size = info["chunks_size"]
|
chunks_size = info["chunks_size"]
|
||||||
ep_chunk = ep_idx // chunks_size
|
ep_chunk = ep_idx // chunks_size
|
||||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
table = hf_dataset.data.table
|
table = hf_dataset.data.table
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||||
@@ -125,7 +125,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||||
def _create_multi_episode_parquet(
|
def _create_multi_episode_parquet(
|
||||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
input_dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||||
) -> Path:
|
) -> Path:
|
||||||
if not info:
|
if not info:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
@@ -137,11 +137,11 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
|||||||
total_episodes = info["total_episodes"]
|
total_episodes = info["total_episodes"]
|
||||||
for ep_idx in range(total_episodes):
|
for ep_idx in range(total_episodes):
|
||||||
ep_chunk = ep_idx // chunks_size
|
ep_chunk = ep_idx // chunks_size
|
||||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
table = hf_dataset.data.table
|
table = hf_dataset.data.table
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||||
pq.write_table(ep_table, fpath)
|
pq.write_table(ep_table, fpath)
|
||||||
return dir / "data"
|
return input_dir / "data"
|
||||||
|
|
||||||
return _create_multi_episode_parquet
|
return _create_multi_episode_parquet
|
||||||
|
|||||||
6
tests/fixtures/hub.py
vendored
6
tests/fixtures/hub.py
vendored
@@ -81,12 +81,12 @@ def mock_snapshot_download_factory(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _mock_snapshot_download(
|
def _mock_snapshot_download(
|
||||||
repo_id: str,
|
_repo_id: str,
|
||||||
|
*_args,
|
||||||
local_dir: str | Path | None = None,
|
local_dir: str | Path | None = None,
|
||||||
allow_patterns: str | list[str] | None = None,
|
allow_patterns: str | list[str] | None = None,
|
||||||
ignore_patterns: str | list[str] | None = None,
|
ignore_patterns: str | list[str] | None = None,
|
||||||
*args,
|
**_kwargs,
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
if not local_dir:
|
if not local_dir:
|
||||||
local_dir = LEROBOT_TEST_DIR
|
local_dir = LEROBOT_TEST_DIR
|
||||||
|
|||||||
12
tests/fixtures/optimizers.py
vendored
12
tests/fixtures/optimizers.py
vendored
@@ -18,13 +18,13 @@ from lerobot.common.optim.optimizers import AdamConfig
|
|||||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="model_params")
|
||||||
def model_params():
|
def fixture_model_params():
|
||||||
return [torch.nn.Parameter(torch.randn(10, 10))]
|
return [torch.nn.Parameter(torch.randn(10, 10))]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="optimizer")
|
||||||
def optimizer(model_params):
|
def fixture_optimizer(model_params):
|
||||||
optimizer = AdamConfig().build(model_params)
|
optimizer = AdamConfig().build(model_params)
|
||||||
# Dummy step to populate state
|
# Dummy step to populate state
|
||||||
loss = sum(param.sum() for param in model_params)
|
loss = sum(param.sum() for param in model_params)
|
||||||
@@ -33,7 +33,7 @@ def optimizer(model_params):
|
|||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="scheduler")
|
||||||
def scheduler(optimizer):
|
def fixture_scheduler(optimizer):
|
||||||
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
|
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
|
||||||
return config.build(optimizer, num_training_steps=100)
|
return config.build(optimizer, num_training_steps=100)
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ from lerobot.common.datasets.utils import create_lerobot_dataset_card
|
|||||||
def test_default_parameters():
|
def test_default_parameters():
|
||||||
card = create_lerobot_dataset_card()
|
card = create_lerobot_dataset_card()
|
||||||
assert isinstance(card, DatasetCard)
|
assert isinstance(card, DatasetCard)
|
||||||
|
# TODO(Steven): Base class CardDate should have 'tags' as a member if we want RepoCard to hold a reference to this abstraction
|
||||||
|
# card.data gives a CardDate type, implementations of this class do have 'tags' but the base class doesn't
|
||||||
assert card.data.tags == ["LeRobot"]
|
assert card.data.tags == ["LeRobot"]
|
||||||
assert card.data.task_categories == ["robotics"]
|
assert card.data.task_categories == ["robotics"]
|
||||||
assert card.data.configs == [
|
assert card.data.configs == [
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def rotate(color_image, rotation):
|
|||||||
|
|
||||||
|
|
||||||
class VideoCapture:
|
class VideoCapture:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *_args, **_kwargs):
|
||||||
self._mock_dict = {
|
self._mock_dict = {
|
||||||
CAP_PROP_FPS: 30,
|
CAP_PROP_FPS: 30,
|
||||||
CAP_PROP_FRAME_WIDTH: 640,
|
CAP_PROP_FRAME_WIDTH: 640,
|
||||||
|
|||||||
@@ -24,10 +24,9 @@ DEFAULT_BAUDRATE = 9_600
|
|||||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bytes(value, bytes):
|
def convert_to_bytes(value, _byte):
|
||||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||||
# `convert_bytes_to_value`
|
# `convert_bytes_to_value`
|
||||||
del bytes # unused
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@@ -74,7 +73,7 @@ class PacketHandler:
|
|||||||
|
|
||||||
|
|
||||||
class GroupSyncRead:
|
class GroupSyncRead:
|
||||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
def __init__(self, _port_handler, packet_handler, _address, _byte):
|
||||||
self.packet_handler = packet_handler
|
self.packet_handler = packet_handler
|
||||||
|
|
||||||
def addParam(self, motor_index): # noqa: N802
|
def addParam(self, motor_index): # noqa: N802
|
||||||
@@ -85,12 +84,12 @@ class GroupSyncRead:
|
|||||||
def txRxPacket(self): # noqa: N802
|
def txRxPacket(self): # noqa: N802
|
||||||
return COMM_SUCCESS
|
return COMM_SUCCESS
|
||||||
|
|
||||||
def getData(self, index, address, bytes): # noqa: N802
|
def getData(self, index, address, _byte): # noqa: N802
|
||||||
return self.packet_handler.data[index][address]
|
return self.packet_handler.data[index][address]
|
||||||
|
|
||||||
|
|
||||||
class GroupSyncWrite:
|
class GroupSyncWrite:
|
||||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
def __init__(self, _port_handler, packet_handler, address, _byte):
|
||||||
self.packet_handler = packet_handler
|
self.packet_handler = packet_handler
|
||||||
self.address = address
|
self.address = address
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,13 @@ class format(enum.Enum): # noqa: N801
|
|||||||
|
|
||||||
|
|
||||||
class config: # noqa: N801
|
class config: # noqa: N801
|
||||||
|
device_enabled = None
|
||||||
|
stream_type = None
|
||||||
|
width = None
|
||||||
|
height = None
|
||||||
|
color_format = None
|
||||||
|
fps = None
|
||||||
|
|
||||||
def enable_device(self, device_id: str):
|
def enable_device(self, device_id: str):
|
||||||
self.device_enabled = device_id
|
self.device_enabled = device_id
|
||||||
|
|
||||||
@@ -125,8 +132,7 @@ class RSDevice:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_info(self, camera_info) -> str:
|
def get_info(self, _camera_info) -> str:
|
||||||
del camera_info # unused
|
|
||||||
# return fake serial number
|
# return fake serial number
|
||||||
return "123456789"
|
return "123456789"
|
||||||
|
|
||||||
@@ -145,4 +151,3 @@ class camera_info: # noqa: N801
|
|||||||
|
|
||||||
def __init__(self, serial_number):
|
def __init__(self, serial_number):
|
||||||
del serial_number
|
del serial_number
|
||||||
pass
|
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ DEFAULT_BAUDRATE = 1_000_000
|
|||||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bytes(value, bytes):
|
def convert_to_bytes(value, byte):
|
||||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||||
# `convert_bytes_to_value`
|
# `convert_bytes_to_value`
|
||||||
del bytes # unused
|
del byte # unused
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ class PacketHandler:
|
|||||||
|
|
||||||
|
|
||||||
class GroupSyncRead:
|
class GroupSyncRead:
|
||||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
def __init__(self, _port_handler, packet_handler, _address, _byte):
|
||||||
self.packet_handler = packet_handler
|
self.packet_handler = packet_handler
|
||||||
|
|
||||||
def addParam(self, motor_index): # noqa: N802
|
def addParam(self, motor_index): # noqa: N802
|
||||||
@@ -96,12 +96,12 @@ class GroupSyncRead:
|
|||||||
def txRxPacket(self): # noqa: N802
|
def txRxPacket(self): # noqa: N802
|
||||||
return COMM_SUCCESS
|
return COMM_SUCCESS
|
||||||
|
|
||||||
def getData(self, index, address, bytes): # noqa: N802
|
def getData(self, index, address, _byte): # noqa: N802
|
||||||
return self.packet_handler.data[index][address]
|
return self.packet_handler.data[index][address]
|
||||||
|
|
||||||
|
|
||||||
class GroupSyncWrite:
|
class GroupSyncWrite:
|
||||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
def __init__(self, _port_handler, packet_handler, address, _byte):
|
||||||
self.packet_handler = packet_handler
|
self.packet_handler = packet_handler
|
||||||
self.address = address
|
self.address = address
|
||||||
|
|
||||||
|
|||||||
@@ -81,11 +81,11 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
for dataset in [
|
for available_dataset in [
|
||||||
"lerobot/pusht",
|
"lerobot/pusht",
|
||||||
"lerobot/aloha_sim_insertion_human",
|
"lerobot/aloha_sim_insertion_human",
|
||||||
"lerobot/xarm_lift_medium",
|
"lerobot/xarm_lift_medium",
|
||||||
"lerobot/nyu_franka_play_dataset",
|
"lerobot/nyu_franka_play_dataset",
|
||||||
"lerobot/cmu_stretch",
|
"lerobot/cmu_stretch",
|
||||||
]:
|
]:
|
||||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
|
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=available_dataset)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
|
|||||||
}
|
}
|
||||||
|
|
||||||
frames = {"original_frame": original_frame}
|
frames = {"original_frame": original_frame}
|
||||||
for tf_type, tf_name, min_max_values in transforms.items():
|
for tf_type, tf_name, min_max_values in transforms:
|
||||||
for min_max in min_max_values:
|
for min_max in min_max_values:
|
||||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
||||||
tf = make_transform_from_config(tf_cfg)
|
tf = make_transform_from_config(tf_cfg)
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ def test_camera(request, camera_type, mock):
|
|||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
manual_rot_img: np.ndarray = None
|
||||||
if rotation is None:
|
if rotation is None:
|
||||||
manual_rot_img = ori_color_image
|
manual_rot_img = ori_color_image
|
||||||
assert camera.rotation is None
|
assert camera.rotation is None
|
||||||
@@ -197,10 +198,14 @@ def test_camera(request, camera_type, mock):
|
|||||||
@require_camera
|
@require_camera
|
||||||
def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
|
def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
|
||||||
# TODO(rcadene): refactor
|
# TODO(rcadene): refactor
|
||||||
|
save_images_from_cameras = None
|
||||||
|
|
||||||
if camera_type == "opencv":
|
if camera_type == "opencv":
|
||||||
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
|
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
|
||||||
elif camera_type == "intelrealsense":
|
elif camera_type == "intelrealsense":
|
||||||
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
|
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported camera type: {camera_type}")
|
||||||
|
|
||||||
# Small `record_time_s` to speedup unit tests
|
# Small `record_time_s` to speedup unit tests
|
||||||
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
|
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
|
||||||
|
|||||||
@@ -30,12 +30,12 @@ from lerobot.common.datasets.compute_stats import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
def mock_load_image_as_numpy(_path, dtype, channel_first):
|
||||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="sample_array")
|
||||||
def sample_array():
|
def fixture_sample_array():
|
||||||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ def test_sample_indices():
|
|||||||
|
|
||||||
|
|
||||||
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
||||||
def test_sample_images(mock_load):
|
def test_sample_images(_mock_load):
|
||||||
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
||||||
images = sample_images(image_paths)
|
images = sample_images(image_paths)
|
||||||
assert isinstance(images, np.ndarray)
|
assert isinstance(images, np.ndarray)
|
||||||
|
|||||||
@@ -48,8 +48,8 @@ from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
|||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="image_dataset")
|
||||||
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
def fixture_image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {
|
features = {
|
||||||
"image": {
|
"image": {
|
||||||
"dtype": "image",
|
"dtype": "image",
|
||||||
@@ -374,7 +374,7 @@ def test_factory(env_name, repo_id, policy_name):
|
|||||||
if required:
|
if required:
|
||||||
assert key in item, f"{key}"
|
assert key in item, f"{key}"
|
||||||
else:
|
else:
|
||||||
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
|
logging.warning('Missing key in dataset: "%s" not in %s.', key, dataset)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if delta_timestamps is not None and key in delta_timestamps:
|
if delta_timestamps is not None and key in delta_timestamps:
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
|||||||
table = hf_dataset.data.table
|
table = hf_dataset.data.table
|
||||||
total_episodes = calculate_total_episode(hf_dataset)
|
total_episodes = calculate_total_episode(hf_dataset)
|
||||||
for ep_idx in range(total_episodes):
|
for ep_idx in range(total_episodes):
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
ep_table = table.filter(
|
||||||
|
pc.equal(table["episode_index"], ep_idx)
|
||||||
|
) # TODO(Steven): What is this check supposed to do?
|
||||||
episode_lengths.insert(ep_idx, len(ep_table))
|
episode_lengths.insert(ep_idx, len(ep_table))
|
||||||
|
|
||||||
cumulative_lengths = list(accumulate(episode_lengths))
|
cumulative_lengths = list(accumulate(episode_lengths))
|
||||||
@@ -52,8 +54,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="synced_timestamps_factory", scope="module")
|
||||||
def synced_timestamps_factory(hf_dataset_factory):
|
def fixture_synced_timestamps_factory(hf_dataset_factory):
|
||||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
hf_dataset = hf_dataset_factory(fps=fps)
|
hf_dataset = hf_dataset_factory(fps=fps)
|
||||||
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
||||||
@@ -64,8 +66,8 @@ def synced_timestamps_factory(hf_dataset_factory):
|
|||||||
return _create_synced_timestamps
|
return _create_synced_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="unsynced_timestamps_factory", scope="module")
|
||||||
def unsynced_timestamps_factory(synced_timestamps_factory):
|
def fixture_unsynced_timestamps_factory(synced_timestamps_factory):
|
||||||
def _create_unsynced_timestamps(
|
def _create_unsynced_timestamps(
|
||||||
fps: int = 30, tolerance_s: float = 1e-4
|
fps: int = 30, tolerance_s: float = 1e-4
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
@@ -76,8 +78,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
|
|||||||
return _create_unsynced_timestamps
|
return _create_unsynced_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="slightly_off_timestamps_factory", scope="module")
|
||||||
def slightly_off_timestamps_factory(synced_timestamps_factory):
|
def fixture_slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||||
def _create_slightly_off_timestamps(
|
def _create_slightly_off_timestamps(
|
||||||
fps: int = 30, tolerance_s: float = 1e-4
|
fps: int = 30, tolerance_s: float = 1e-4
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
@@ -88,22 +90,26 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
|
|||||||
return _create_slightly_off_timestamps
|
return _create_slightly_off_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="valid_delta_timestamps_factory", scope="module")
|
||||||
def valid_delta_timestamps_factory():
|
def fixture_valid_delta_timestamps_factory():
|
||||||
def _create_valid_delta_timestamps(
|
def _create_valid_delta_timestamps(
|
||||||
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
fps: int = 30, keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if keys is None:
|
||||||
|
keys = DUMMY_MOTOR_FEATURES
|
||||||
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
||||||
return delta_timestamps
|
return delta_timestamps
|
||||||
|
|
||||||
return _create_valid_delta_timestamps
|
return _create_valid_delta_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="invalid_delta_timestamps_factory", scope="module")
|
||||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
def fixture_invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||||
def _create_invalid_delta_timestamps(
|
def _create_invalid_delta_timestamps(
|
||||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if keys is None:
|
||||||
|
keys = DUMMY_MOTOR_FEATURES
|
||||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||||
# Modify a single timestamp just outside tolerance
|
# Modify a single timestamp just outside tolerance
|
||||||
for key in keys:
|
for key in keys:
|
||||||
@@ -113,11 +119,13 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
|||||||
return _create_invalid_delta_timestamps
|
return _create_invalid_delta_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="slightly_off_delta_timestamps_factory", scope="module")
|
||||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
def fixture_slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||||
def _create_slightly_off_delta_timestamps(
|
def _create_slightly_off_delta_timestamps(
|
||||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
if keys is None:
|
||||||
|
keys = DUMMY_MOTOR_FEATURES
|
||||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||||
# Modify a single timestamp just inside tolerance
|
# Modify a single timestamp just inside tolerance
|
||||||
for key in delta_timestamps:
|
for key in delta_timestamps:
|
||||||
@@ -128,9 +136,11 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
|||||||
return _create_slightly_off_delta_timestamps
|
return _create_slightly_off_delta_timestamps
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(name="delta_indices_factory", scope="module")
|
||||||
def delta_indices_factory():
|
def fixture_delta_indices_factory():
|
||||||
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
def _delta_indices(keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||||
|
if keys is None:
|
||||||
|
keys = DUMMY_MOTOR_FEATURES
|
||||||
return {key: list(range(*min_max_range)) for key in keys}
|
return {key: list(range(*min_max_range)) for key in keys}
|
||||||
|
|
||||||
return _delta_indices
|
return _delta_indices
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def _run_script(path):
|
|||||||
|
|
||||||
|
|
||||||
def _read_file(path):
|
def _read_file(path):
|
||||||
with open(path) as file:
|
with open(path, encoding="utf-8") as file:
|
||||||
return file.read()
|
return file.read()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -37,8 +37,8 @@ from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
|||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="color_jitters")
|
||||||
def color_jitters():
|
def fixture_color_jitters():
|
||||||
return [
|
return [
|
||||||
v2.ColorJitter(brightness=0.5),
|
v2.ColorJitter(brightness=0.5),
|
||||||
v2.ColorJitter(contrast=0.5),
|
v2.ColorJitter(contrast=0.5),
|
||||||
@@ -46,18 +46,18 @@ def color_jitters():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="single_transforms")
|
||||||
def single_transforms():
|
def fixture_single_transforms():
|
||||||
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="single_transforms")
|
||||||
def img_tensor(single_transforms):
|
def fixture_img_tensor(single_transforms):
|
||||||
return single_transforms["original_frame"]
|
return single_transforms["original_frame"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="default_transforms")
|
||||||
def default_transforms():
|
def fixture_default_transforms():
|
||||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import pytest
|
|||||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="tmp_json_file")
|
||||||
def tmp_json_file(tmp_path: Path):
|
def fixture_tmp_json_file(tmp_path: Path):
|
||||||
"""Writes `data` to a temporary JSON file and returns the file's path."""
|
"""Writes `data` to a temporary JSON file and returns the file's path."""
|
||||||
|
|
||||||
def _write(data: Any) -> Path:
|
def _write(data: Any) -> Path:
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import pytest
|
|||||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="mock_metrics")
|
||||||
def mock_metrics():
|
def fixture_mock_metrics():
|
||||||
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
||||||
|
|
||||||
|
|
||||||
@@ -87,6 +87,7 @@ def test_metrics_tracker_getattr(mock_metrics):
|
|||||||
_ = tracker.non_existent_metric
|
_ = tracker.non_existent_metric
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): I don't understand what's supposed to happen here
|
||||||
def test_metrics_tracker_setattr(mock_metrics):
|
def test_metrics_tracker_setattr(mock_metrics):
|
||||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||||
tracker.loss = 2.0
|
tracker.loss = 2.0
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ def test_non_mutate():
|
|||||||
def test_index_error_no_data():
|
def test_index_error_no_data():
|
||||||
buffer, _ = make_new_buffer()
|
buffer, _ = make_new_buffer()
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(IndexError):
|
||||||
buffer[0]
|
_ = buffer[0]
|
||||||
|
|
||||||
|
|
||||||
def test_index_error_with_data():
|
def test_index_error_with_data():
|
||||||
@@ -83,9 +83,9 @@ def test_index_error_with_data():
|
|||||||
new_data = make_spoof_data_frames(1, n_frames)
|
new_data = make_spoof_data_frames(1, n_frames)
|
||||||
buffer.add_data(new_data)
|
buffer.add_data(new_data)
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(IndexError):
|
||||||
buffer[n_frames]
|
_ = buffer[n_frames]
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(IndexError):
|
||||||
buffer[-n_frames - 1]
|
_ = buffer[-n_frames - 1]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("do_reload", [False, True])
|
@pytest.mark.parametrize("do_reload", [False, True])
|
||||||
@@ -185,7 +185,7 @@ def test_delta_timestamps_outside_tolerance_inside_episode_range():
|
|||||||
buffer.add_data(new_data)
|
buffer.add_data(new_data)
|
||||||
buffer.tolerance_s = 0.04
|
buffer.tolerance_s = 0.04
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
buffer[2]
|
_ = buffer[2]
|
||||||
|
|
||||||
|
|
||||||
def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
||||||
@@ -229,6 +229,7 @@ def test_compute_sampler_weights_trivial(
|
|||||||
weights = compute_sampler_weights(
|
weights = compute_sampler_weights(
|
||||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||||
)
|
)
|
||||||
|
expected_weights: torch.Tensor = None
|
||||||
if offline_dataset_size == 0 or online_dataset_size == 0:
|
if offline_dataset_size == 0 or online_dataset_size == 0:
|
||||||
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
||||||
elif online_sampling_ratio == 0:
|
elif online_sampling_ratio == 0:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
# pylint: disable=redefined-outer-name, unused-argument
|
||||||
import inspect
|
import inspect
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -251,7 +252,7 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
|||||||
policy_cfg.input_features = {
|
policy_cfg.input_features = {
|
||||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||||
}
|
}
|
||||||
policy = policy_cls(policy_cfg)
|
policy = policy_cls(policy_cfg) # config.device = gpu
|
||||||
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
|
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||||
policy.save_pretrained(save_dir)
|
policy.save_pretrained(save_dir)
|
||||||
policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
# pylint: disable=redefined-outer-name, unused-argument
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -32,16 +32,16 @@ from lerobot.common.utils.import_utils import is_package_available
|
|||||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
TEST_ROBOT_TYPES = []
|
TEST_ROBOT_TYPES = []
|
||||||
for robot_type in available_robots:
|
for available_robot_type in available_robots:
|
||||||
TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
|
TEST_ROBOT_TYPES += [(available_robot_type, True), (available_robot_type, False)]
|
||||||
|
|
||||||
TEST_CAMERA_TYPES = []
|
TEST_CAMERA_TYPES = []
|
||||||
for camera_type in available_cameras:
|
for available_camera_type in available_cameras:
|
||||||
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
|
TEST_CAMERA_TYPES += [(available_camera_type, True), (available_camera_type, False)]
|
||||||
|
|
||||||
TEST_MOTOR_TYPES = []
|
TEST_MOTOR_TYPES = []
|
||||||
for motor_type in available_motors:
|
for available_motor_type in available_motors:
|
||||||
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
|
TEST_MOTOR_TYPES += [(available_motor_type, True), (available_motor_type, False)]
|
||||||
|
|
||||||
# Camera indices used for connecting physical cameras
|
# Camera indices used for connecting physical cameras
|
||||||
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
||||||
@@ -72,7 +72,6 @@ def require_x86_64_kernel(func):
|
|||||||
"""
|
"""
|
||||||
Decorator that skips the test if plateform device is not an x86_64 cpu.
|
Decorator that skips the test if plateform device is not an x86_64 cpu.
|
||||||
"""
|
"""
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
@@ -87,7 +86,6 @@ def require_cpu(func):
|
|||||||
"""
|
"""
|
||||||
Decorator that skips the test if device is not cpu.
|
Decorator that skips the test if device is not cpu.
|
||||||
"""
|
"""
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
@@ -102,7 +100,6 @@ def require_cuda(func):
|
|||||||
"""
|
"""
|
||||||
Decorator that skips the test if cuda is not available.
|
Decorator that skips the test if cuda is not available.
|
||||||
"""
|
"""
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
@@ -288,17 +285,17 @@ def mock_calibration_dir(calibration_dir):
|
|||||||
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||||
}
|
}
|
||||||
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
|
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
|
||||||
with open(calibration_dir / "main_follower.json", "w") as f:
|
with open(calibration_dir / "main_follower.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "main_leader.json", "w") as f:
|
with open(calibration_dir / "main_leader.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "left_follower.json", "w") as f:
|
with open(calibration_dir / "left_follower.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "left_leader.json", "w") as f:
|
with open(calibration_dir / "left_leader.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "right_follower.json", "w") as f:
|
with open(calibration_dir / "right_follower.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
with open(calibration_dir / "right_leader.json", "w") as f:
|
with open(calibration_dir / "right_leader.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(example_calib, f)
|
json.dump(example_calib, f)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user