Add mock robot & teleop
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig
|
||||
|
||||
__all__ = ["Camera", "CameraConfig"]
|
||||
from .utils import make_cameras_from_configs
|
||||
|
||||
@@ -57,6 +57,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .viperx import ViperX
|
||||
|
||||
return ViperX(config)
|
||||
elif config.type == "mock_robot":
|
||||
from tests.mocks.mock_robot import MockRobot
|
||||
|
||||
return MockRobot(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
|
||||
|
||||
@@ -23,5 +23,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from .widowx import WidowX
|
||||
|
||||
return WidowX(config)
|
||||
elif config.type == "mock_teleop":
|
||||
from tests.mocks.mock_teleop import MockTeleop
|
||||
|
||||
return MockTeleop(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
|
||||
110
tests/mocks/mock_robot.py
Normal file
110
tests/mocks/mock_robot.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.cameras import CameraConfig, make_cameras_from_configs
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.robots import Robot, RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("mock_robot")
|
||||
@dataclass
|
||||
class MockRobotConfig(RobotConfig):
|
||||
n_motors: int = 3
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
random_values: bool = True
|
||||
static_values: list[float] | None = None
|
||||
calibrated: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_motors < 1:
|
||||
raise ValueError(self.n_motors)
|
||||
|
||||
if self.random_values and self.static_values is not None:
|
||||
raise ValueError("Choose either random values or static values")
|
||||
|
||||
if self.static_values is not None and len(self.static_values) != self.n_motors:
|
||||
raise ValueError("Specify the same number of static values as motors")
|
||||
|
||||
if len(self.cameras) > 0:
|
||||
raise NotImplementedError # TODO with the cameras refactor
|
||||
|
||||
|
||||
class MockRobot(Robot):
|
||||
"""Mock Robot to be used for testing."""
|
||||
|
||||
config_class = MockRobotConfig
|
||||
name = "mock_robot"
|
||||
|
||||
def __init__(self, config: MockRobotConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._is_connected = False
|
||||
self._is_calibrated = config.calibrated
|
||||
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.motors}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._is_connected
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self._is_connected = True
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_calibrated = True
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
return {
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
return action
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_connected = False
|
||||
92
tests/mocks/mock_teleop.py
Normal file
92
tests/mocks/mock_teleop.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.teleoperators import Teleoperator, TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("mock_teleop")
|
||||
@dataclass
|
||||
class MockTeleopConfig(TeleoperatorConfig):
|
||||
n_motors: int = 3
|
||||
random_values: bool = True
|
||||
static_values: list[float] | None = None
|
||||
calibrated: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_motors < 1:
|
||||
raise ValueError(self.n_motors)
|
||||
|
||||
if self.random_values and self.static_values is not None:
|
||||
raise ValueError("Choose either random values or static values")
|
||||
|
||||
if self.static_values is not None and len(self.static_values) != self.n_motors:
|
||||
raise ValueError("Specify the same number of static values as motors")
|
||||
|
||||
|
||||
class MockTeleop(Teleoperator):
|
||||
"""Mock Teleoperator to be used for testing."""
|
||||
|
||||
config_class = MockTeleopConfig
|
||||
name = "mock_teleop"
|
||||
|
||||
def __init__(self, config: MockTeleopConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._is_connected = False
|
||||
self._is_calibrated = config.calibrated
|
||||
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.motors}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.motors}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._is_connected
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self._is_connected = True
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_calibrated = True
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
return {
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_connected = False
|
||||
Reference in New Issue
Block a user