Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots

This commit is contained in:
Simon Alibert
2025-03-10 18:39:48 +01:00
135 changed files with 2177 additions and 514 deletions

View File

@@ -0,0 +1,13 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,89 @@
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Generator
import pytest
from lerobot.common.envs.configs import EnvConfig
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
return f"""
from dataclasses import dataclass
from lerobot.common.envs.configs import {base_class}
@{base_class}.register_subclass("{plugin_name}")
@dataclass
class TestPluginConfig:
value: int = 42
"""
@pytest.fixture
def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]:
"""Creates a temporary plugin package structure."""
plugin_pkg = tmp_path / "test_plugin"
plugin_pkg.mkdir()
(plugin_pkg / "__init__.py").touch()
with open(plugin_pkg / "my_plugin.py", "w") as f:
f.write(create_plugin_code())
# Add tmp_path to Python path so we can import from it
sys.path.insert(0, str(tmp_path))
yield plugin_pkg
sys.path.pop(0)
def test_parse_plugin_args():
cli_args = [
"--env.type=test",
"--model.discover_packages_path=some.package",
"--env.discover_packages_path=other.package",
]
plugin_args = parse_plugin_args("discover_packages_path", cli_args)
assert plugin_args == {
"model.discover_packages_path": "some.package",
"env.discover_packages_path": "other.package",
}
def test_load_plugin_success(plugin_dir: Path):
# Import should work and register the plugin with the real EnvConfig
load_plugin("test_plugin")
assert "test_env" in EnvConfig.get_known_choices()
plugin_cls = EnvConfig.get_choice_class("test_env")
plugin_instance = plugin_cls()
assert plugin_instance.value == 42
def test_load_plugin_failure():
with pytest.raises(PluginLoadError) as exc_info:
load_plugin("nonexistent_plugin")
assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value)
def test_wrap_with_plugin(plugin_dir: Path):
@dataclass
class Config:
env: EnvConfig
@wrap()
def dummy_func(cfg: Config):
return cfg
# Test loading plugin via CLI args
sys.argv = [
"dummy_script.py",
"--env.discover_packages_path=test_plugin",
"--env.type=test_env",
]
cfg = dummy_func()
assert isinstance(cfg, Config)
assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env"))
assert cfg.env.value == 42

View File

@@ -36,51 +36,27 @@ def pytest_collection_finish():
print(f"\nTesting with {DEVICE=}")
@pytest.fixture
def is_robot_available(robot_type):
if robot_type not in available_robots:
def _check_component_availability(component_type, available_components, make_component):
"""Generic helper to check if a hardware component is available"""
if component_type not in available_components:
raise ValueError(
f"The robot type '{robot_type}' is not valid. Expected one of these '{available_robots}"
f"The {component_type} type is not valid. Expected one of these '{available_components}'"
)
try:
robot = make_robot(robot_type)
robot.connect()
del robot
component = make_component(component_type)
component.connect()
del component
return True
except Exception as e:
print(f"\nA {robot_type} robot is not available.")
print(f"\nA {component_type} is not available.")
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, SerialException):
print("\nNo physical motors bus detected.")
else:
traceback.print_exc()
return False
@pytest.fixture
def is_camera_available(camera_type):
if camera_type not in available_cameras:
raise ValueError(
f"The camera type '{camera_type}' is not valid. Expected one of these '{available_cameras}"
)
try:
camera = make_camera(camera_type)
camera.connect()
del camera
return True
except Exception as e:
print(f"\nA {camera_type} camera is not available.")
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, ValueError) and "camera_index" in e.args[0]:
print("\nNo physical device detected.")
elif isinstance(e, ValueError) and "camera_index" in str(e):
print("\nNo physical camera detected.")
else:
traceback.print_exc()
@@ -88,30 +64,19 @@ def is_camera_available(camera_type):
return False
@pytest.fixture
def is_robot_available(robot_type):
return _check_component_availability(robot_type, available_robots, make_robot)
@pytest.fixture
def is_camera_available(camera_type):
return _check_component_availability(camera_type, available_cameras, make_camera)
@pytest.fixture
def is_motor_available(motor_type):
if motor_type not in available_motors:
raise ValueError(
f"The motor type '{motor_type}' is not valid. Expected one of these '{available_motors}"
)
try:
motors_bus = make_motors_bus(motor_type)
motors_bus.connect()
del motors_bus
return True
except Exception as e:
print(f"\nA {motor_type} motor is not available.")
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, SerialException):
print("\nNo physical motors bus detected.")
else:
traceback.print_exc()
return False
return _check_component_availability(motor_type, available_motors, make_motors_bus)
@pytest.fixture

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.common.constants import HF_LEROBOT_HOME
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from functools import partial
from pathlib import Path

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path

13
tests/fixtures/hub.py vendored
View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import datasets

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

View File

@@ -1,7 +1,25 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cache
import numpy as np
CAP_V4L2 = 200
CAP_DSHOW = 700
CAP_AVFOUNDATION = 1200
CAP_ANY = -1
CAP_PROP_FPS = 5
CAP_PROP_FRAME_WIDTH = 3
CAP_PROP_FRAME_HEIGHT = 4

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import numpy as np

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)

View File

@@ -33,12 +33,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
device="cpu",
)
train_cfg.validate() # Needed for auto-setting some parameters
dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for physical cameras and their mocked versions.
If the physical camera is not connected to the computer, or not working,
@@ -72,8 +85,8 @@ def test_camera(request, camera_type, mock):
camera.connect()
assert camera.is_connected
assert camera.fps is not None
assert camera.width is not None
assert camera.height is not None
assert camera.capture_width is not None
assert camera.capture_height is not None
# Test connecting twice raises an error
with pytest.raises(DeviceAlreadyConnectedError):
@@ -191,3 +204,49 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
# Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
@require_camera
def test_camera_rotation(request, camera_type, mock):
config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30}
# No rotation.
camera = make_camera(**config_kwargs, rotation=None)
camera.connect()
assert camera.capture_width == 640
assert camera.capture_height == 480
assert camera.width == 640
assert camera.height == 480
no_rot_img = camera.read()
h, w, c = no_rot_img.shape
assert h == 480 and w == 640 and c == 3
camera.disconnect()
# Rotation = 90 (clockwise).
camera = make_camera(**config_kwargs, rotation=90)
camera.connect()
# With a 90° rotation, we expect the metadata dimensions to be swapped.
assert camera.capture_width == 640
assert camera.capture_height == 480
assert camera.width == 480
assert camera.height == 640
import cv2
assert camera.rotation == cv2.ROTATE_90_CLOCKWISE
rot_img = camera.read()
h, w, c = rot_img.shape
assert h == 640 and w == 480 and c == 3
camera.disconnect()
# Rotation = 180.
camera = make_camera(**config_kwargs, rotation=None)
camera.connect()
assert camera.capture_width == 640
assert camera.capture_height == 480
assert camera.width == 640
assert camera.height == 480
no_rot_img = camera.read()
h, w, c = no_rot_img.shape
assert h == 480 and w == 640 and c == 3
camera.disconnect()

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for physical robots and their mocked versions.
If the physical robots are not connected to the computer, or not working,
@@ -39,7 +52,7 @@ from lerobot.configs.control import (
from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@@ -171,7 +184,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
replay(robot, replay_cfg)
policy_cfg = ACTConfig()
policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
policy = make_policy(policy_cfg, ds_meta=dataset.meta)
out_dir = tmp_path / "logger"
@@ -216,8 +229,6 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
device=DEVICE,
use_amp=False,
)
rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)

View File

@@ -45,7 +45,7 @@ from lerobot.common.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import DEVICE, require_x86_64_kernel
from tests.utils import require_x86_64_kernel
@pytest.fixture
@@ -349,7 +349,6 @@ def test_factory(env_name, repo_id, policy_name):
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
env=make_env_config(env_name),
policy=make_policy_config(policy_name),
device=DEVICE,
)
dataset = make_dataset(cfg)

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import accumulate
import datasets

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import time
from multiprocessing import queues

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
from typing import Any

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for physical motors and their mocked versions.
If the physical motors are not connected to the computer, or not working,

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

View File

@@ -143,12 +143,11 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
env=make_env_config(env_name, **env_kwargs),
device=DEVICE,
)
# Check that we can make the policy object.
dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
assert isinstance(policy, PreTrainedPolicy)
# Check that we run select_actions and get the appropriate output.
@@ -214,7 +213,6 @@ def test_act_backbone_lr():
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
device=DEVICE,
)
cfg.validate() # Needed for auto-setting some parameters
@@ -222,7 +220,7 @@ def test_act_backbone_lr():
assert cfg.policy.optimizer_lr_backbone == 0.001
dataset = make_dataset(cfg)
policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
@@ -254,10 +252,11 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
policy = policy_cls(policy_cfg)
policy.to(policy_cfg.device)
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
@@ -369,7 +368,7 @@ def test_normalize(insert_temporal_dim):
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
# ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for physical robots and their mocked versions.
If the physical robots are not connected to the computer, or not working,

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim.lr_scheduler import LambdaLR
from lerobot.common.constants import SCHEDULER_STATE

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from unittest.mock import Mock, patch

View File

@@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import Dataset