forked from tangger/lerobot
Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
89
tests/configs/test_plugin_loading.py
Normal file
89
tests/configs/test_plugin_loading.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
|
||||
|
||||
|
||||
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
|
||||
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
|
||||
return f"""
|
||||
from dataclasses import dataclass
|
||||
from lerobot.common.envs.configs import {base_class}
|
||||
|
||||
@{base_class}.register_subclass("{plugin_name}")
|
||||
@dataclass
|
||||
class TestPluginConfig:
|
||||
value: int = 42
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]:
|
||||
"""Creates a temporary plugin package structure."""
|
||||
plugin_pkg = tmp_path / "test_plugin"
|
||||
plugin_pkg.mkdir()
|
||||
(plugin_pkg / "__init__.py").touch()
|
||||
|
||||
with open(plugin_pkg / "my_plugin.py", "w") as f:
|
||||
f.write(create_plugin_code())
|
||||
|
||||
# Add tmp_path to Python path so we can import from it
|
||||
sys.path.insert(0, str(tmp_path))
|
||||
yield plugin_pkg
|
||||
sys.path.pop(0)
|
||||
|
||||
|
||||
def test_parse_plugin_args():
|
||||
cli_args = [
|
||||
"--env.type=test",
|
||||
"--model.discover_packages_path=some.package",
|
||||
"--env.discover_packages_path=other.package",
|
||||
]
|
||||
plugin_args = parse_plugin_args("discover_packages_path", cli_args)
|
||||
assert plugin_args == {
|
||||
"model.discover_packages_path": "some.package",
|
||||
"env.discover_packages_path": "other.package",
|
||||
}
|
||||
|
||||
|
||||
def test_load_plugin_success(plugin_dir: Path):
|
||||
# Import should work and register the plugin with the real EnvConfig
|
||||
load_plugin("test_plugin")
|
||||
|
||||
assert "test_env" in EnvConfig.get_known_choices()
|
||||
plugin_cls = EnvConfig.get_choice_class("test_env")
|
||||
plugin_instance = plugin_cls()
|
||||
assert plugin_instance.value == 42
|
||||
|
||||
|
||||
def test_load_plugin_failure():
|
||||
with pytest.raises(PluginLoadError) as exc_info:
|
||||
load_plugin("nonexistent_plugin")
|
||||
assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_wrap_with_plugin(plugin_dir: Path):
|
||||
@dataclass
|
||||
class Config:
|
||||
env: EnvConfig
|
||||
|
||||
@wrap()
|
||||
def dummy_func(cfg: Config):
|
||||
return cfg
|
||||
|
||||
# Test loading plugin via CLI args
|
||||
sys.argv = [
|
||||
"dummy_script.py",
|
||||
"--env.discover_packages_path=test_plugin",
|
||||
"--env.type=test_env",
|
||||
]
|
||||
|
||||
cfg = dummy_func()
|
||||
assert isinstance(cfg, Config)
|
||||
assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env"))
|
||||
assert cfg.env.value == 42
|
||||
@@ -36,51 +36,27 @@ def pytest_collection_finish():
|
||||
print(f"\nTesting with {DEVICE=}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_robot_available(robot_type):
|
||||
if robot_type not in available_robots:
|
||||
def _check_component_availability(component_type, available_components, make_component):
|
||||
"""Generic helper to check if a hardware component is available"""
|
||||
if component_type not in available_components:
|
||||
raise ValueError(
|
||||
f"The robot type '{robot_type}' is not valid. Expected one of these '{available_robots}"
|
||||
f"The {component_type} type is not valid. Expected one of these '{available_components}'"
|
||||
)
|
||||
|
||||
try:
|
||||
robot = make_robot(robot_type)
|
||||
robot.connect()
|
||||
del robot
|
||||
component = make_component(component_type)
|
||||
component.connect()
|
||||
del component
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nA {robot_type} robot is not available.")
|
||||
print(f"\nA {component_type} is not available.")
|
||||
|
||||
if isinstance(e, ModuleNotFoundError):
|
||||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, SerialException):
|
||||
print("\nNo physical motors bus detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_camera_available(camera_type):
|
||||
if camera_type not in available_cameras:
|
||||
raise ValueError(
|
||||
f"The camera type '{camera_type}' is not valid. Expected one of these '{available_cameras}"
|
||||
)
|
||||
|
||||
try:
|
||||
camera = make_camera(camera_type)
|
||||
camera.connect()
|
||||
del camera
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nA {camera_type} camera is not available.")
|
||||
|
||||
if isinstance(e, ModuleNotFoundError):
|
||||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, ValueError) and "camera_index" in e.args[0]:
|
||||
print("\nNo physical device detected.")
|
||||
elif isinstance(e, ValueError) and "camera_index" in str(e):
|
||||
print("\nNo physical camera detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
@@ -88,30 +64,19 @@ def is_camera_available(camera_type):
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_robot_available(robot_type):
|
||||
return _check_component_availability(robot_type, available_robots, make_robot)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_camera_available(camera_type):
|
||||
return _check_component_availability(camera_type, available_cameras, make_camera)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_motor_available(motor_type):
|
||||
if motor_type not in available_motors:
|
||||
raise ValueError(
|
||||
f"The motor type '{motor_type}' is not valid. Expected one of these '{available_motors}"
|
||||
)
|
||||
|
||||
try:
|
||||
motors_bus = make_motors_bus(motor_type)
|
||||
motors_bus.connect()
|
||||
del motors_bus
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nA {motor_type} motor is not available.")
|
||||
|
||||
if isinstance(e, ModuleNotFoundError):
|
||||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, SerialException):
|
||||
print("\nNo physical motors bus detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
return False
|
||||
return _check_component_availability(motor_type, available_motors, make_motors_bus)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
13
tests/fixtures/constants.py
vendored
13
tests/fixtures/constants.py
vendored
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
|
||||
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
|
||||
|
||||
13
tests/fixtures/dataset_factories.py
vendored
13
tests/fixtures/dataset_factories.py
vendored
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
13
tests/fixtures/files.py
vendored
13
tests/fixtures/files.py
vendored
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
13
tests/fixtures/hub.py
vendored
13
tests/fixtures/hub.py
vendored
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
|
||||
13
tests/fixtures/optimizers.py
vendored
13
tests/fixtures/optimizers.py
vendored
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
@@ -1,7 +1,25 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import cache
|
||||
|
||||
import numpy as np
|
||||
|
||||
CAP_V4L2 = 200
|
||||
CAP_DSHOW = 700
|
||||
CAP_AVFOUNDATION = 1200
|
||||
CAP_ANY = -1
|
||||
|
||||
CAP_PROP_FPS = 5
|
||||
CAP_PROP_FRAME_WIDTH = 3
|
||||
CAP_PROP_FRAME_HEIGHT = 4
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
|
||||
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
|
||||
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
|
||||
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
|
||||
|
||||
|
||||
@@ -33,12 +33,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
device="cpu",
|
||||
)
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
dataset = make_dataset(train_cfg)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
policy.train()
|
||||
|
||||
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical cameras and their mocked versions.
|
||||
If the physical camera is not connected to the computer, or not working,
|
||||
@@ -72,8 +85,8 @@ def test_camera(request, camera_type, mock):
|
||||
camera.connect()
|
||||
assert camera.is_connected
|
||||
assert camera.fps is not None
|
||||
assert camera.width is not None
|
||||
assert camera.height is not None
|
||||
assert camera.capture_width is not None
|
||||
assert camera.capture_height is not None
|
||||
|
||||
# Test connecting twice raises an error
|
||||
with pytest.raises(DeviceAlreadyConnectedError):
|
||||
@@ -191,3 +204,49 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
|
||||
|
||||
# Small `record_time_s` to speedup unit tests
|
||||
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
|
||||
@require_camera
|
||||
def test_camera_rotation(request, camera_type, mock):
|
||||
config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30}
|
||||
|
||||
# No rotation.
|
||||
camera = make_camera(**config_kwargs, rotation=None)
|
||||
camera.connect()
|
||||
assert camera.capture_width == 640
|
||||
assert camera.capture_height == 480
|
||||
assert camera.width == 640
|
||||
assert camera.height == 480
|
||||
no_rot_img = camera.read()
|
||||
h, w, c = no_rot_img.shape
|
||||
assert h == 480 and w == 640 and c == 3
|
||||
camera.disconnect()
|
||||
|
||||
# Rotation = 90 (clockwise).
|
||||
camera = make_camera(**config_kwargs, rotation=90)
|
||||
camera.connect()
|
||||
# With a 90° rotation, we expect the metadata dimensions to be swapped.
|
||||
assert camera.capture_width == 640
|
||||
assert camera.capture_height == 480
|
||||
assert camera.width == 480
|
||||
assert camera.height == 640
|
||||
import cv2
|
||||
|
||||
assert camera.rotation == cv2.ROTATE_90_CLOCKWISE
|
||||
rot_img = camera.read()
|
||||
h, w, c = rot_img.shape
|
||||
assert h == 640 and w == 480 and c == 3
|
||||
camera.disconnect()
|
||||
|
||||
# Rotation = 180.
|
||||
camera = make_camera(**config_kwargs, rotation=None)
|
||||
camera.connect()
|
||||
assert camera.capture_width == 640
|
||||
assert camera.capture_height == 480
|
||||
assert camera.width == 640
|
||||
assert camera.height == 480
|
||||
no_rot_img = camera.read()
|
||||
h, w, c = no_rot_img.shape
|
||||
assert h == 480 and w == 640 and c == 3
|
||||
camera.disconnect()
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical robots and their mocked versions.
|
||||
If the physical robots are not connected to the computer, or not working,
|
||||
@@ -39,7 +52,7 @@ from lerobot.configs.control import (
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from tests.test_robots import make_robot
|
||||
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@@ -171,7 +184,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||
replay(robot, replay_cfg)
|
||||
|
||||
policy_cfg = ACTConfig()
|
||||
policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
|
||||
policy = make_policy(policy_cfg, ds_meta=dataset.meta)
|
||||
|
||||
out_dir = tmp_path / "logger"
|
||||
|
||||
@@ -216,8 +229,6 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
device=DEVICE,
|
||||
use_amp=False,
|
||||
)
|
||||
|
||||
rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)
|
||||
|
||||
@@ -45,7 +45,7 @@ from lerobot.common.robots.utils import make_robot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.utils import DEVICE, require_x86_64_kernel
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -349,7 +349,6 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
|
||||
env=make_env_config(env_name),
|
||||
policy=make_policy_config(policy_name),
|
||||
device=DEVICE,
|
||||
)
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from itertools import accumulate
|
||||
|
||||
import datasets
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import queue
|
||||
import time
|
||||
from multiprocessing import queues
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical motors and their mocked versions.
|
||||
If the physical motors are not connected to the computer, or not working,
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
@@ -143,12 +143,11 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
env=make_env_config(env_name, **env_kwargs),
|
||||
device=DEVICE,
|
||||
)
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(train_cfg)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
assert isinstance(policy, PreTrainedPolicy)
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
@@ -214,7 +213,6 @@ def test_act_backbone_lr():
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
device=DEVICE,
|
||||
)
|
||||
cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
@@ -222,7 +220,7 @@ def test_act_backbone_lr():
|
||||
assert cfg.policy.optimizer_lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta)
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
|
||||
@@ -254,10 +252,11 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||
}
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.to(policy_cfg.device)
|
||||
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
policy.save_pretrained(save_dir)
|
||||
policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
||||
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
@@ -369,7 +368,7 @@ def test_normalize(insert_temporal_dim):
|
||||
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
|
||||
# to test with `policy.use_mpc=false`.
|
||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
|
||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
||||
# ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
||||
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
|
||||
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
|
||||
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical robots and their mocked versions.
|
||||
If the physical robots are not connected to the computer, or not working,
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.common.constants import SCHEDULER_STATE
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
Reference in New Issue
Block a user