forked from tangger/lerobot
fix(tests): remove lint warnings/errors
@@ -108,7 +108,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu


 def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
-    for i in enumerate(stats_list):
+    for i in range(len(stats_list)):
         for fkey in stats_list[i]:
             for k, v in stats_list[i][fkey].items():
                 if not isinstance(v, np.ndarray):
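Note: this fix direction matters. `enumerate()` yields `(index, item)` tuples, so the old `for i in enumerate(stats_list)` bound `i` to a tuple and `stats_list[i]` would raise `TypeError`. A minimal sketch of the broken and fixed idioms (sample data is illustrative):

stats_list = [{"mean": 1.0}, {"mean": 2.0}]

# Buggy: `i` is a (index, item) tuple, so stats_list[i] raises TypeError.
# for i in enumerate(stats_list):
#     _ = stats_list[i]

# Fixed: plain integer indices.
for i in range(len(stats_list)):
    assert isinstance(stats_list[i], dict)

# Idiomatic alternative when both index and item are needed:
for i, stats in enumerate(stats_list):
    assert stats_list[i] is stats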
tests/fixtures/dataset_factories.py (vendored, 68 changed lines)
@@ -52,16 +52,16 @@ def get_task_index(task_dicts: dict, task: str) -> int:
     return task_to_task_index[task]


-@pytest.fixture(scope="session")
-def img_tensor_factory():
+@pytest.fixture(name="img_tensor_factory", scope="session")
+def fixture_img_tensor_factory():
     def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
         return torch.rand((channels, height, width), dtype=dtype)

     return _create_img_tensor


-@pytest.fixture(scope="session")
-def img_array_factory():
+@pytest.fixture(name="img_array_factory", scope="session")
+def fixture_img_array_factory():
     def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
         if np.issubdtype(dtype, np.unsignedinteger):
             # Int array in [0, 255] range
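Note: the renames in this file all follow one pattern. pytest injects fixtures by parameter name, so a test argument named `img_tensor_factory` shadows the module-level function of the same name and pylint reports W0621 (redefined-outer-name). Declaring the fixture with an explicit `name=` and prefixing the function with `fixture_` keeps the public fixture name stable while removing the shadowing. A minimal sketch:

import pytest

# Before: a test argument `img` would shadow the function `img` (pylint W0621).
# @pytest.fixture
# def img():
#     return "fake-image"

# After: the function name no longer collides; tests still request "img".
@pytest.fixture(name="img")
def fixture_img():
    return "fake-image"

def test_uses_img(img):
    assert img == "fake-image"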
@@ -76,8 +76,8 @@ def img_array_factory():
     return _create_img_array


-@pytest.fixture(scope="session")
-def img_factory(img_array_factory):
+@pytest.fixture(name="img_factory", scope="session")
+def fixture_img_factory(img_array_factory):
     def _create_img(height=100, width=100) -> PIL.Image.Image:
         img_array = img_array_factory(height=height, width=width)
         return PIL.Image.fromarray(img_array)
@@ -85,13 +85,17 @@ def img_factory(img_array_factory):
     return _create_img


-@pytest.fixture(scope="session")
-def features_factory():
+@pytest.fixture(name="features_factory", scope="session")
+def fixture_features_factory():
     def _create_features(
-        motor_features: dict = DUMMY_MOTOR_FEATURES,
-        camera_features: dict = DUMMY_CAMERA_FEATURES,
+        motor_features: dict | None = None,
+        camera_features: dict | None = None,
         use_videos: bool = True,
     ) -> dict:
+        if motor_features is None:
+            motor_features = DUMMY_MOTOR_FEATURES
+        if camera_features is None:
+            camera_features = DUMMY_CAMERA_FEATURES
         if use_videos:
             camera_ft = {
                 key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
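Note: the second recurring pattern replaces mutable default arguments. `DUMMY_MOTOR_FEATURES` is a dict, and a mutable default is created once and shared across all calls, which pylint flags as W0102 (dangerous-default-value). The `None` sentinel plus an in-body check is the standard fix. A sketch (names are illustrative):

DUMMY = {"count": 0}

# Risky: every call shares, and can mutate, the same default dict.
def make_features_shared(features: dict = DUMMY) -> dict:
    features["count"] += 1
    return features

# Fixed: None sentinel, resolved inside the function body.
def make_features(features: dict | None = None) -> dict:
    if features is None:
        features = dict(DUMMY)  # shallow copy so the constant stays untouched
    return features

make_features_shared()
make_features_shared()
assert DUMMY["count"] == 2  # the module-level "constant" was silently mutated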
@@ -107,8 +111,8 @@ def features_factory():
     return _create_features


-@pytest.fixture(scope="session")
-def info_factory(features_factory):
+@pytest.fixture(name="info_factory", scope="session")
+def fixture_info_factory(features_factory):
     def _create_info(
         codebase_version: str = CODEBASE_VERSION,
         fps: int = DEFAULT_FPS,
@@ -121,10 +125,14 @@ def info_factory(features_factory):
         chunks_size: int = DEFAULT_CHUNK_SIZE,
         data_path: str = DEFAULT_PARQUET_PATH,
         video_path: str = DEFAULT_VIDEO_PATH,
-        motor_features: dict = DUMMY_MOTOR_FEATURES,
-        camera_features: dict = DUMMY_CAMERA_FEATURES,
+        motor_features: dict | None = None,
+        camera_features: dict | None = None,
         use_videos: bool = True,
     ) -> dict:
+        if motor_features is None:
+            motor_features = DUMMY_MOTOR_FEATURES
+        if camera_features is None:
+            camera_features = DUMMY_CAMERA_FEATURES
         features = features_factory(motor_features, camera_features, use_videos)
         return {
             "codebase_version": codebase_version,
@@ -145,8 +153,8 @@ def info_factory(features_factory):
     return _create_info


-@pytest.fixture(scope="session")
-def stats_factory():
+@pytest.fixture(name="stats_factory", scope="session")
+def fixture_stats_factory():
     def _create_stats(
         features: dict[str] | None = None,
     ) -> dict:
@@ -175,8 +183,8 @@ def stats_factory():
     return _create_stats


-@pytest.fixture(scope="session")
-def episodes_stats_factory(stats_factory):
+@pytest.fixture(name="episodes_stats_factory", scope="session")
+def fixture_episodes_stats_factory(stats_factory):
     def _create_episodes_stats(
         features: dict[str],
         total_episodes: int = 3,
@@ -192,8 +200,8 @@ def episodes_stats_factory(stats_factory):
     return _create_episodes_stats


-@pytest.fixture(scope="session")
-def tasks_factory():
+@pytest.fixture(name="tasks_factory", scope="session")
+def fixture_tasks_factory():
     def _create_tasks(total_tasks: int = 3) -> int:
         tasks = {}
         for task_index in range(total_tasks):
@@ -204,8 +212,8 @@ def tasks_factory():
     return _create_tasks


-@pytest.fixture(scope="session")
-def episodes_factory(tasks_factory):
+@pytest.fixture(name="episodes_factory", scope="session")
+def fixture_episodes_factory(tasks_factory):
     def _create_episodes(
         total_episodes: int = 3,
         total_frames: int = 400,
@@ -252,8 +260,8 @@ def episodes_factory(tasks_factory):
     return _create_episodes


-@pytest.fixture(scope="session")
-def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
+@pytest.fixture(name="hf_dataset_factory", scope="session")
+def fixture_hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
     def _create_hf_dataset(
         features: dict | None = None,
         tasks: list[dict] | None = None,
@@ -310,8 +318,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
     return _create_hf_dataset


-@pytest.fixture(scope="session")
-def lerobot_dataset_metadata_factory(
+@pytest.fixture(name="lerobot_dataset_metadata_factory", scope="session")
+def fixture_lerobot_dataset_metadata_factory(
     info_factory,
     stats_factory,
     episodes_stats_factory,
@@ -364,8 +372,8 @@ def lerobot_dataset_metadata_factory(
     return _create_lerobot_dataset_metadata


-@pytest.fixture(scope="session")
-def lerobot_dataset_factory(
+@pytest.fixture(name="lerobot_dataset_factory", scope="session")
+def fixture_lerobot_dataset_factory(
     info_factory,
     stats_factory,
     episodes_stats_factory,
@@ -443,6 +451,6 @@ def lerobot_dataset_factory(
     return _create_lerobot_dataset


-@pytest.fixture(scope="session")
-def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
+@pytest.fixture(name="empty_lerobot_dataset_factory", scope="session")
+def fixture_empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
     return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
tests/fixtures/files.py (vendored, 34 changed lines)
@@ -31,12 +31,12 @@ from lerobot.common.datasets.utils import (


 @pytest.fixture(scope="session")
 def info_path(info_factory):
-    def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
+    def _create_info_json_file(input_dir: Path, info: dict | None = None) -> Path:
         if not info:
             info = info_factory()
-        fpath = dir / INFO_PATH
+        fpath = input_dir / INFO_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
-        with open(fpath, "w") as f:
+        with open(fpath, "w", encoding="utf-8") as f:
             json.dump(info, f, indent=4, ensure_ascii=False)
         return fpath
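Note: each helper in this file gets the same two fixes. `dir` shadows the `dir()` builtin (pylint W0622, redefined-builtin), and `open()` without an explicit `encoding` triggers W1514 (unspecified-encoding) because the default encoding is platform-dependent. A self-contained sketch, with an illustrative file layout:

import json
import tempfile
from pathlib import Path

def write_info_json(input_dir: Path, info: dict) -> Path:
    # `input_dir` instead of `dir`: no builtin is shadowed.
    fpath = input_dir / "meta" / "info.json"  # layout is illustrative
    fpath.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding: identical behavior on every platform.
    with open(fpath, "w", encoding="utf-8") as f:
        json.dump(info, f, indent=4, ensure_ascii=False)
    return fpath

print(write_info_json(Path(tempfile.mkdtemp()), {"fps": 30}))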
@@ -45,12 +45,12 @@ def info_path(info_factory):


 @pytest.fixture(scope="session")
 def stats_path(stats_factory):
-    def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
+    def _create_stats_json_file(input_dir: Path, stats: dict | None = None) -> Path:
         if not stats:
             stats = stats_factory()
-        fpath = dir / STATS_PATH
+        fpath = input_dir / STATS_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
-        with open(fpath, "w") as f:
+        with open(fpath, "w", encoding="utf-8") as f:
             json.dump(stats, f, indent=4, ensure_ascii=False)
         return fpath
@@ -59,10 +59,10 @@ def stats_path(stats_factory):


 @pytest.fixture(scope="session")
 def episodes_stats_path(episodes_stats_factory):
-    def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
+    def _create_episodes_stats_jsonl_file(input_dir: Path, episodes_stats: list[dict] | None = None) -> Path:
         if not episodes_stats:
             episodes_stats = episodes_stats_factory()
-        fpath = dir / EPISODES_STATS_PATH
+        fpath = input_dir / EPISODES_STATS_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with jsonlines.open(fpath, "w") as writer:
             writer.write_all(episodes_stats.values())
@@ -73,10 +73,10 @@ def episodes_stats_path(episodes_stats_factory):


 @pytest.fixture(scope="session")
 def tasks_path(tasks_factory):
-    def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
+    def _create_tasks_jsonl_file(input_dir: Path, tasks: list | None = None) -> Path:
         if not tasks:
             tasks = tasks_factory()
-        fpath = dir / TASKS_PATH
+        fpath = input_dir / TASKS_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with jsonlines.open(fpath, "w") as writer:
             writer.write_all(tasks.values())
@@ -87,10 +87,10 @@ def tasks_path(tasks_factory):


 @pytest.fixture(scope="session")
 def episode_path(episodes_factory):
-    def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
+    def _create_episodes_jsonl_file(input_dir: Path, episodes: list | None = None) -> Path:
         if not episodes:
             episodes = episodes_factory()
-        fpath = dir / EPISODES_PATH
+        fpath = input_dir / EPISODES_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with jsonlines.open(fpath, "w") as writer:
             writer.write_all(episodes.values())
@@ -102,7 +102,7 @@ def episode_path(episodes_factory):

 @pytest.fixture(scope="session")
 def single_episode_parquet_path(hf_dataset_factory, info_factory):
     def _create_single_episode_parquet(
-        dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
+        input_dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
     ) -> Path:
         if not info:
             info = info_factory()
@@ -112,7 +112,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
         data_path = info["data_path"]
         chunks_size = info["chunks_size"]
         ep_chunk = ep_idx // chunks_size
-        fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+        fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
         fpath.parent.mkdir(parents=True, exist_ok=True)
         table = hf_dataset.data.table
         ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
@@ -125,7 +125,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):

 @pytest.fixture(scope="session")
 def multi_episode_parquet_path(hf_dataset_factory, info_factory):
     def _create_multi_episode_parquet(
-        dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
+        input_dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
     ) -> Path:
         if not info:
             info = info_factory()
@@ -137,11 +137,11 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
         total_episodes = info["total_episodes"]
         for ep_idx in range(total_episodes):
             ep_chunk = ep_idx // chunks_size
-            fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+            fpath = input_dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
             fpath.parent.mkdir(parents=True, exist_ok=True)
             table = hf_dataset.data.table
             ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
             pq.write_table(ep_table, fpath)
-        return dir / "data"
+        return input_dir / "data"

     return _create_multi_episode_parquet
tests/fixtures/hub.py (vendored, 6 changed lines)
@@ -81,12 +81,12 @@ def mock_snapshot_download_factory(
         return None

     def _mock_snapshot_download(
-        repo_id: str,
+        _repo_id: str,
+        *_args,
         local_dir: str | Path | None = None,
         allow_patterns: str | list[str] | None = None,
         ignore_patterns: str | list[str] | None = None,
-        *args,
-        **kwargs,
+        **_kwargs,
     ) -> str:
         if not local_dir:
             local_dir = LEROBOT_TEST_DIR
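Note: mocks often accept arguments they never read, which pylint reports as W0613 (unused-argument). Prefixing a parameter with an underscore documents that it is deliberately ignored and silences the warning without changing any call site. A sketch of the convention:

from pathlib import Path

# Underscore-prefixed parameters are accepted but deliberately unused.
def mock_download(_repo_id: str, *_args, local_dir: str | Path | None = None, **_kwargs) -> str:
    return str(local_dir or "/tmp/lerobot-tests")  # fallback path is illustrative

# Call sites keep passing the real arguments; the mock discards them.
print(mock_download("lerobot/pusht", "some-revision", local_dir=None, token="x"))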
tests/fixtures/optimizers.py (vendored, 12 changed lines)
@@ -18,13 +18,13 @@ from lerobot.common.optim.optimizers import AdamConfig
 from lerobot.common.optim.schedulers import VQBeTSchedulerConfig


-@pytest.fixture
-def model_params():
+@pytest.fixture(name="model_params")
+def fixture_model_params():
     return [torch.nn.Parameter(torch.randn(10, 10))]


-@pytest.fixture
-def optimizer(model_params):
+@pytest.fixture(name="optimizer")
+def fixture_optimizer(model_params):
     optimizer = AdamConfig().build(model_params)
     # Dummy step to populate state
     loss = sum(param.sum() for param in model_params)
@@ -33,7 +33,7 @@ def optimizer(model_params):
     return optimizer


-@pytest.fixture
-def scheduler(optimizer):
+@pytest.fixture(name="scheduler")
+def fixture_scheduler(optimizer):
     config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
     return config.build(optimizer, num_training_steps=100)
@@ -22,6 +22,8 @@ from lerobot.common.datasets.utils import create_lerobot_dataset_card
 def test_default_parameters():
     card = create_lerobot_dataset_card()
     assert isinstance(card, DatasetCard)
+    # TODO(Steven): Base class CardData should have 'tags' as a member if we want RepoCard to hold a reference to this abstraction
+    # card.data gives a CardData type, implementations of this class do have 'tags' but the base class doesn't
     assert card.data.tags == ["LeRobot"]
     assert card.data.task_categories == ["robotics"]
     assert card.data.configs == [
@@ -52,7 +52,7 @@ def rotate(color_image, rotation):


 class VideoCapture:
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *_args, **_kwargs):
         self._mock_dict = {
             CAP_PROP_FPS: 30,
             CAP_PROP_FRAME_WIDTH: 640,
@@ -24,10 +24,9 @@ DEFAULT_BAUDRATE = 9_600
 COMM_SUCCESS = 0  # tx or rx packet communication success


-def convert_to_bytes(value, bytes):
+def convert_to_bytes(value, _byte):
     # TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
     # `convert_bytes_to_value`
-    del bytes  # unused
     return value
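Note: `bytes` is a Python builtin, so using it as a parameter name raises W0622 (redefined-builtin). Renaming to `_byte` both removes the shadowing and marks the parameter unused, which makes the old `del bytes  # unused` workaround unnecessary. A sketch:

# Before: shadows the builtin and needs `del` to silence unused-argument.
# def convert_to_bytes(value, bytes):
#     del bytes  # unused
#     return value

# After: no shadowing; the underscore already signals "unused".
def convert_to_bytes(value, _byte):
    return value

assert convert_to_bytes(42, 4) == 42
assert bytes([65]) == b"A"  # the builtin is still usable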
@@ -74,7 +73,7 @@ class PacketHandler:


 class GroupSyncRead:
-    def __init__(self, port_handler, packet_handler, address, bytes):
+    def __init__(self, _port_handler, packet_handler, _address, _byte):
         self.packet_handler = packet_handler

     def addParam(self, motor_index):  # noqa: N802
@@ -85,12 +84,12 @@ class GroupSyncRead:
     def txRxPacket(self):  # noqa: N802
         return COMM_SUCCESS

-    def getData(self, index, address, bytes):  # noqa: N802
+    def getData(self, index, address, _byte):  # noqa: N802
         return self.packet_handler.data[index][address]


 class GroupSyncWrite:
-    def __init__(self, port_handler, packet_handler, address, bytes):
+    def __init__(self, _port_handler, packet_handler, address, _byte):
         self.packet_handler = packet_handler
         self.address = address
@@ -27,6 +27,13 @@ class format(enum.Enum):  # noqa: N801


 class config:  # noqa: N801
+    device_enabled = None
+    stream_type = None
+    width = None
+    height = None
+    color_format = None
+    fps = None
+
     def enable_device(self, device_id: str):
         self.device_enabled = device_id
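Note: the added class-level attributes address W0201 (attribute-defined-outside-init), which fires when an instance attribute first appears in a method other than `__init__`. Class-level defaults declare the attribute up front. A sketch:

class Config:
    device_enabled = None  # declared before any method assigns it

    def enable_device(self, device_id: str):
        self.device_enabled = device_id

cfg = Config()
assert cfg.device_enabled is None
cfg.enable_device("123456789")
assert cfg.device_enabled == "123456789"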
@@ -125,8 +132,7 @@ class RSDevice:
     def __init__(self):
         pass

-    def get_info(self, camera_info) -> str:
-        del camera_info  # unused
+    def get_info(self, _camera_info) -> str:
         # return fake serial number
         return "123456789"
@@ -145,4 +151,3 @@ class camera_info:  # noqa: N801

     def __init__(self, serial_number):
         del serial_number
-        pass
@@ -24,10 +24,10 @@ DEFAULT_BAUDRATE = 1_000_000
 COMM_SUCCESS = 0  # tx or rx packet communication success


-def convert_to_bytes(value, bytes):
+def convert_to_bytes(value, byte):
     # TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
     # `convert_bytes_to_value`
-    del bytes  # unused
+    del byte  # unused
     return value
@@ -85,7 +85,7 @@ class PacketHandler:


 class GroupSyncRead:
-    def __init__(self, port_handler, packet_handler, address, bytes):
+    def __init__(self, _port_handler, packet_handler, _address, _byte):
         self.packet_handler = packet_handler

     def addParam(self, motor_index):  # noqa: N802
@@ -96,12 +96,12 @@ class GroupSyncRead:
     def txRxPacket(self):  # noqa: N802
         return COMM_SUCCESS

-    def getData(self, index, address, bytes):  # noqa: N802
+    def getData(self, index, address, _byte):  # noqa: N802
         return self.packet_handler.data[index][address]


 class GroupSyncWrite:
-    def __init__(self, port_handler, packet_handler, address, bytes):
+    def __init__(self, _port_handler, packet_handler, address, _byte):
         self.packet_handler = packet_handler
         self.address = address
@@ -81,11 +81,11 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):


 if __name__ == "__main__":
-    for dataset in [
+    for available_dataset in [
         "lerobot/pusht",
         "lerobot/aloha_sim_insertion_human",
         "lerobot/xarm_lift_medium",
         "lerobot/nyu_franka_play_dataset",
         "lerobot/cmu_stretch",
     ]:
-        save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
+        save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=available_dataset)
@@ -51,7 +51,7 @@ def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
     }

     frames = {"original_frame": original_frame}
-    for tf_type, tf_name, min_max_values in transforms.items():
+    for tf_type, tf_name, min_max_values in transforms:
         for min_max in min_max_values:
             tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
             tf = make_transform_from_config(tf_cfg)
@@ -150,6 +150,7 @@ def test_camera(request, camera_type, mock):
     else:
         import cv2

+    manual_rot_img: np.ndarray = None
     if rotation is None:
         manual_rot_img = ori_color_image
         assert camera.rotation is None
@@ -197,10 +198,14 @@ def test_camera(request, camera_type, mock):
 @require_camera
 def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
     # TODO(rcadene): refactor
+    save_images_from_cameras = None
+
     if camera_type == "opencv":
         from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
     elif camera_type == "intelrealsense":
         from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
+    else:
+        raise ValueError(f"Unsupported camera type: {camera_type}")

     # Small `record_time_s` to speedup unit tests
     save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
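Note: both camera-test changes target E0606 (possibly-used-before-assignment): a name bound only in some branches may be unbound at the point of use. Initializing up front and making the branch chain exhaustive with a raising `else` guarantees the name is always defined. A sketch:

def pick_loader(kind: str) -> str:
    loader = None  # bound on every path, so the later use is always safe
    if kind == "opencv":
        loader = "load_with_opencv"
    elif kind == "intelrealsense":
        loader = "load_with_realsense"
    else:
        raise ValueError(f"Unsupported camera type: {kind}")
    return loader

assert pick_loader("opencv") == "load_with_opencv"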
@@ -30,12 +30,12 @@ from lerobot.common.datasets.compute_stats import (
 )


-def mock_load_image_as_numpy(path, dtype, channel_first):
+def mock_load_image_as_numpy(_path, dtype, channel_first):
     return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)


-@pytest.fixture
-def sample_array():
+@pytest.fixture(name="sample_array")
+def fixture_sample_array():
     return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -62,7 +62,7 @@ def test_sample_indices():


 @patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
-def test_sample_images(mock_load):
+def test_sample_images(_mock_load):
     image_paths = [f"image_{i}.jpg" for i in range(100)]
     images = sample_images(image_paths)
     assert isinstance(images, np.ndarray)
@@ -48,8 +48,8 @@ from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
 from tests.utils import require_x86_64_kernel


-@pytest.fixture
-def image_dataset(tmp_path, empty_lerobot_dataset_factory):
+@pytest.fixture(name="image_dataset")
+def fixture_image_dataset(tmp_path, empty_lerobot_dataset_factory):
     features = {
         "image": {
             "dtype": "image",
@@ -374,7 +374,7 @@ def test_factory(env_name, repo_id, policy_name):
         if required:
             assert key in item, f"{key}"
         else:
-            logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
+            logging.warning('Missing key in dataset: "%s" not in %s.', key, dataset)
             continue

         if delta_timestamps is not None and key in delta_timestamps:
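Note: the logging change fixes W1203 (logging-fstring-interpolation). An f-string is formatted before `logging.warning` even checks whether the record will be emitted; `%`-style arguments let the logging module defer formatting until a handler actually needs the message. A sketch:

import logging

logging.basicConfig(level=logging.WARNING)
key, dataset = "action", "lerobot/pusht"

# Eager: the message string is built unconditionally.
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')

# Lazy: formatting is deferred to the logging machinery.
logging.warning('Missing key in dataset: "%s" not in %s.', key, dataset)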
@@ -42,7 +42,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
     table = hf_dataset.data.table
     total_episodes = calculate_total_episode(hf_dataset)
     for ep_idx in range(total_episodes):
-        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+        ep_table = table.filter(
+            pc.equal(table["episode_index"], ep_idx)
+        )  # TODO(Steven): What is this check supposed to do?
         episode_lengths.insert(ep_idx, len(ep_table))

     cumulative_lengths = list(accumulate(episode_lengths))
@@ -52,8 +54,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
     }


-@pytest.fixture(scope="module")
-def synced_timestamps_factory(hf_dataset_factory):
+@pytest.fixture(name="synced_timestamps_factory", scope="module")
+def fixture_synced_timestamps_factory(hf_dataset_factory):
     def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         hf_dataset = hf_dataset_factory(fps=fps)
         timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
@@ -64,8 +66,8 @@ def synced_timestamps_factory(hf_dataset_factory):
     return _create_synced_timestamps


-@pytest.fixture(scope="module")
-def unsynced_timestamps_factory(synced_timestamps_factory):
+@pytest.fixture(name="unsynced_timestamps_factory", scope="module")
+def fixture_unsynced_timestamps_factory(synced_timestamps_factory):
     def _create_unsynced_timestamps(
         fps: int = 30, tolerance_s: float = 1e-4
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -76,8 +78,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
     return _create_unsynced_timestamps


-@pytest.fixture(scope="module")
-def slightly_off_timestamps_factory(synced_timestamps_factory):
+@pytest.fixture(name="slightly_off_timestamps_factory", scope="module")
+def fixture_slightly_off_timestamps_factory(synced_timestamps_factory):
     def _create_slightly_off_timestamps(
         fps: int = 30, tolerance_s: float = 1e-4
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -88,22 +90,26 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
     return _create_slightly_off_timestamps


-@pytest.fixture(scope="module")
-def valid_delta_timestamps_factory():
+@pytest.fixture(name="valid_delta_timestamps_factory", scope="module")
+def fixture_valid_delta_timestamps_factory():
     def _create_valid_delta_timestamps(
-        fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
+        fps: int = 30, keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)
     ) -> dict:
+        if keys is None:
+            keys = DUMMY_MOTOR_FEATURES
         delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
         return delta_timestamps

     return _create_valid_delta_timestamps


-@pytest.fixture(scope="module")
-def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
+@pytest.fixture(name="invalid_delta_timestamps_factory", scope="module")
+def fixture_invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
     def _create_invalid_delta_timestamps(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
+        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
     ) -> dict:
+        if keys is None:
+            keys = DUMMY_MOTOR_FEATURES
         delta_timestamps = valid_delta_timestamps_factory(fps, keys)
         # Modify a single timestamp just outside tolerance
         for key in keys:
@@ -113,11 +119,13 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
     return _create_invalid_delta_timestamps


-@pytest.fixture(scope="module")
-def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
+@pytest.fixture(name="slightly_off_delta_timestamps_factory", scope="module")
+def fixture_slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
     def _create_slightly_off_delta_timestamps(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
+        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
     ) -> dict:
+        if keys is None:
+            keys = DUMMY_MOTOR_FEATURES
         delta_timestamps = valid_delta_timestamps_factory(fps, keys)
         # Modify a single timestamp just inside tolerance
         for key in delta_timestamps:
@@ -128,9 +136,11 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
     return _create_slightly_off_delta_timestamps


-@pytest.fixture(scope="module")
-def delta_indices_factory():
-    def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
+@pytest.fixture(name="delta_indices_factory", scope="module")
+def fixture_delta_indices_factory():
+    def _delta_indices(keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
+        if keys is None:
+            keys = DUMMY_MOTOR_FEATURES
         return {key: list(range(*min_max_range)) for key in keys}

     return _delta_indices
@@ -38,7 +38,7 @@ def _run_script(path):


 def _read_file(path):
-    with open(path) as file:
+    with open(path, encoding="utf-8") as file:
         return file.read()
@@ -37,8 +37,8 @@ from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
 from tests.utils import require_x86_64_kernel


-@pytest.fixture
-def color_jitters():
+@pytest.fixture(name="color_jitters")
+def fixture_color_jitters():
     return [
         v2.ColorJitter(brightness=0.5),
         v2.ColorJitter(contrast=0.5),
@@ -46,18 +46,18 @@ def color_jitters():
     ]


-@pytest.fixture
-def single_transforms():
+@pytest.fixture(name="single_transforms")
+def fixture_single_transforms():
     return load_file(ARTIFACT_DIR / "single_transforms.safetensors")


-@pytest.fixture
-def img_tensor(single_transforms):
+@pytest.fixture(name="img_tensor")
+def fixture_img_tensor(single_transforms):
     return single_transforms["original_frame"]


-@pytest.fixture
-def default_transforms():
+@pytest.fixture(name="default_transforms")
+def fixture_default_transforms():
     return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
@@ -20,8 +20,8 @@ import pytest
 from lerobot.common.utils.io_utils import deserialize_json_into_object


-@pytest.fixture
-def tmp_json_file(tmp_path: Path):
+@pytest.fixture(name="tmp_json_file")
+def fixture_tmp_json_file(tmp_path: Path):
     """Writes `data` to a temporary JSON file and returns the file's path."""

     def _write(data: Any) -> Path:
@@ -16,8 +16,8 @@ import pytest
 from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker


-@pytest.fixture
-def mock_metrics():
+@pytest.fixture(name="mock_metrics")
+def fixture_mock_metrics():
     return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
@@ -87,6 +87,7 @@ def test_metrics_tracker_getattr(mock_metrics):
         _ = tracker.non_existent_metric


+# TODO(Steven): I don't understand what's supposed to happen here
 def test_metrics_tracker_setattr(mock_metrics):
     tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
     tracker.loss = 2.0
@@ -74,7 +74,7 @@ def test_non_mutate():
 def test_index_error_no_data():
     buffer, _ = make_new_buffer()
     with pytest.raises(IndexError):
-        buffer[0]
+        _ = buffer[0]


 def test_index_error_with_data():
@@ -83,9 +83,9 @@ def test_index_error_with_data():
     new_data = make_spoof_data_frames(1, n_frames)
     buffer.add_data(new_data)
     with pytest.raises(IndexError):
-        buffer[n_frames]
+        _ = buffer[n_frames]
     with pytest.raises(IndexError):
-        buffer[-n_frames - 1]
+        _ = buffer[-n_frames - 1]


 @pytest.mark.parametrize("do_reload", [False, True])
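Note: a bare subscript like `buffer[0]` is an expression statement whose value is discarded, which pylint flags as W0104 (pointless-statement) even inside `pytest.raises`. Assigning to `_` keeps the side effect (the expected exception) and makes the intent explicit. A sketch:

import pytest

def test_indexing_empty_list_raises():
    empty: list[int] = []
    with pytest.raises(IndexError):
        _ = empty[0]  # value is unused; only the raised exception matters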
@@ -185,7 +185,7 @@ def test_delta_timestamps_outside_tolerance_inside_episode_range():
     buffer.add_data(new_data)
     buffer.tolerance_s = 0.04
     with pytest.raises(AssertionError):
-        buffer[2]
+        _ = buffer[2]


 def test_delta_timestamps_outside_tolerance_outside_episode_range():
@@ -229,6 +229,7 @@ def test_compute_sampler_weights_trivial(
     weights = compute_sampler_weights(
         offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
     )
+    expected_weights: torch.Tensor = None
     if offline_dataset_size == 0 or online_dataset_size == 0:
         expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
     elif online_sampling_ratio == 0:
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=redefined-outer-name, unused-argument
 import inspect
 from copy import deepcopy
 from pathlib import Path
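Note: in files that both define and consume fixtures, the parameter shadowing is inherent to pytest's injection model, so these test modules opt for a file-level disable instead of renaming every fixture. A sketch of the pattern (the message names are real pylint ids):

# pylint: disable=redefined-outer-name, unused-argument
import pytest

@pytest.fixture
def policy_name():
    return "act"

def test_policy(policy_name):  # shadows the fixture name: allowed by the disable
    assert policy_name == "act"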
@@ -251,7 +252,7 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
     policy_cfg.input_features = {
         key: ft for key, ft in features.items() if key not in policy_cfg.output_features
     }
-    policy = policy_cls(policy_cfg)
+    policy = policy_cls(policy_cfg)  # config.device = gpu
     save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
     policy.save_pretrained(save_dir)
     policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=redefined-outer-name, unused-argument
 import random

 import numpy as np
@@ -32,16 +32,16 @@ from lerobot.common.utils.import_utils import is_package_available
 DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"

 TEST_ROBOT_TYPES = []
-for robot_type in available_robots:
-    TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
+for available_robot_type in available_robots:
+    TEST_ROBOT_TYPES += [(available_robot_type, True), (available_robot_type, False)]

 TEST_CAMERA_TYPES = []
-for camera_type in available_cameras:
-    TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
+for available_camera_type in available_cameras:
+    TEST_CAMERA_TYPES += [(available_camera_type, True), (available_camera_type, False)]

 TEST_MOTOR_TYPES = []
-for motor_type in available_motors:
-    TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
+for available_motor_type in available_motors:
+    TEST_MOTOR_TYPES += [(available_motor_type, True), (available_motor_type, False)]

 # Camera indices used for connecting physical cameras
 OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
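Note: a module-level loop variable stays bound after the loop ends, so a name like `camera_type` collides with test parameters of the same name elsewhere in the suite, another face of W0621. Renaming the loop variable removes the clash without touching any test. A sketch with illustrative values:

available_cameras = ["opencv", "intelrealsense"]

TEST_CAMERA_TYPES = []
# `available_camera_type` still leaks into module scope after the loop,
# but no test parameter uses that name, so nothing is shadowed.
for available_camera_type in available_cameras:
    TEST_CAMERA_TYPES += [(available_camera_type, True), (available_camera_type, False)]

assert ("opencv", True) in TEST_CAMERA_TYPES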
@@ -72,7 +72,6 @@ def require_x86_64_kernel(func):
     """
     Decorator that skips the test if plateform device is not an x86_64 cpu.
     """
-    from functools import wraps

     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -87,7 +86,6 @@ def require_cpu(func):
     """
     Decorator that skips the test if device is not cpu.
     """
-    from functools import wraps

     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -102,7 +100,6 @@ def require_cuda(func):
     """
     Decorator that skips the test if cuda is not available.
     """
-    from functools import wraps

     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -288,17 +285,17 @@ def mock_calibration_dir(calibration_dir):
         "motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
     }
     Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
-    with open(calibration_dir / "main_follower.json", "w") as f:
+    with open(calibration_dir / "main_follower.json", "w", encoding="utf-8") as f:
         json.dump(example_calib, f)
-    with open(calibration_dir / "main_leader.json", "w") as f:
+    with open(calibration_dir / "main_leader.json", "w", encoding="utf-8") as f:
         json.dump(example_calib, f)
-    with open(calibration_dir / "left_follower.json", "w") as f:
+    with open(calibration_dir / "left_follower.json", "w", encoding="utf-8") as f:
         json.dump(example_calib, f)
-    with open(calibration_dir / "left_leader.json", "w") as f:
+    with open(calibration_dir / "left_leader.json", "w", encoding="utf-8") as f:
         json.dump(example_calib, f)
-    with open(calibration_dir / "right_follower.json", "w") as f:
+    with open(calibration_dir / "right_follower.json", "w", encoding="utf-8") as f:
         json.dump(example_calib, f)
-    with open(calibration_dir / "right_leader.json", "w") as f:
+    with open(calibration_dir / "right_leader.json", "w", encoding="utf-8") as f:
         json.dump(example_calib, f)