diff --git a/lerobot/calibrate.py b/lerobot/calibrate.py new file mode 100644 index 00000000..b21ee89d --- /dev/null +++ b/lerobot/calibrate.py @@ -0,0 +1,60 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import asdict, dataclass +from pprint import pformat + +import draccus + +from lerobot.common.cameras import intel, opencv # noqa: F401 +from lerobot.common.robots import ( # noqa: F401 + Robot, + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( # noqa: F401 + Teleoperator, + TeleoperatorConfig, + make_teleoperator_from_config, +) +from lerobot.common.utils.utils import init_logging + +from .common.teleoperators import koch_leader, so100_leader # noqa: F401 + + +@dataclass +class CalibrateConfig: + device: RobotConfig | TeleoperatorConfig + + +@draccus.wrap() +def calibrate(cfg: CalibrateConfig): + init_logging() + logging.info(pformat(asdict(cfg))) + + if isinstance(cfg.device, RobotConfig): + device = make_robot_from_config(cfg.device) + elif isinstance(cfg.device, TeleoperatorConfig): + device = make_teleoperator_from_config(cfg.device) + + device.connect(calibrate=False) + device.calibrate() + device.disconnect() + + +if __name__ == "__main__": + calibrate() diff --git a/lerobot/common/robots/koch_follower/koch_follower.py b/lerobot/common/robots/koch_follower/koch_follower.py index 42beda09..d37f50ac 100644 --- a/lerobot/common/robots/koch_follower/koch_follower.py +++ b/lerobot/common/robots/koch_follower/koch_follower.py @@ -85,7 +85,7 @@ class KochFollower(Robot): # TODO(aliberts): add cam.is_connected for cam in self.cameras return self.arm.is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. @@ -94,7 +94,7 @@ class KochFollower(Robot): raise DeviceAlreadyConnectedError(f"{self} already connected") self.arm.connect() - if not self.is_calibrated: + if not self.is_calibrated and calibrate: self.calibrate() for cam in self.cameras.values(): diff --git a/lerobot/common/robots/lekiwi/lekiwi.py b/lerobot/common/robots/lekiwi/lekiwi.py index 26fa1331..02474898 100644 --- a/lerobot/common/robots/lekiwi/lekiwi.py +++ b/lerobot/common/robots/lekiwi/lekiwi.py @@ -104,12 +104,12 @@ class LeKiwi(Robot): # TODO(aliberts): add cam.is_connected for cam in self.cameras return self.bus.is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect() - if not self.is_calibrated: + if not self.is_calibrated and calibrate: self.calibrate() for cam in self.cameras.values(): diff --git a/lerobot/common/robots/moss_follower/moss_follower.py b/lerobot/common/robots/moss_follower/moss_follower.py index e27fb8d7..982e2d47 100644 --- a/lerobot/common/robots/moss_follower/moss_follower.py +++ b/lerobot/common/robots/moss_follower/moss_follower.py @@ -82,7 +82,7 @@ class MossRobot(Robot): # TODO(aliberts): add cam.is_connected for cam in self.cameras return self.arm.is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. @@ -91,7 +91,7 @@ class MossRobot(Robot): raise DeviceAlreadyConnectedError(f"{self} already connected") self.arm.connect() - if not self.is_calibrated: + if not self.is_calibrated and calibrate: self.calibrate() # Connect the cameras diff --git a/lerobot/common/robots/robot.py b/lerobot/common/robots/robot.py index e7b58aa3..bd643c17 100644 --- a/lerobot/common/robots/robot.py +++ b/lerobot/common/robots/robot.py @@ -48,7 +48,7 @@ class Robot(abc.ABC): pass @abc.abstractmethod - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: """Connects to the robot.""" pass diff --git a/lerobot/common/robots/so100_follower/so100_follower.py b/lerobot/common/robots/so100_follower/so100_follower.py index 3e063c1b..5f999ae5 100644 --- a/lerobot/common/robots/so100_follower/so100_follower.py +++ b/lerobot/common/robots/so100_follower/so100_follower.py @@ -82,7 +82,7 @@ class SO100Follower(Robot): # TODO(aliberts): add cam.is_connected for cam in self.cameras return self.arm.is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. @@ -91,7 +91,7 @@ class SO100Follower(Robot): raise DeviceAlreadyConnectedError(f"{self} already connected") self.arm.connect() - if not self.is_calibrated: + if not self.is_calibrated and calibrate: self.calibrate() # Connect the cameras diff --git a/lerobot/common/robots/viperx/viperx.py b/lerobot/common/robots/viperx/viperx.py index 7639fd72..3a497113 100644 --- a/lerobot/common/robots/viperx/viperx.py +++ b/lerobot/common/robots/viperx/viperx.py @@ -78,7 +78,7 @@ class ViperX(Robot): # TODO(aliberts): add cam.is_connected for cam in self.cameras return self.arm.is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. @@ -87,7 +87,7 @@ class ViperX(Robot): raise DeviceAlreadyConnectedError(f"{self} already connected") self.arm.connect() - if not self.is_calibrated: + if not self.is_calibrated and calibrate: self.calibrate() for cam in self.cameras.values(): diff --git a/lerobot/common/teleoperators/koch_leader/koch_leader.py b/lerobot/common/teleoperators/koch_leader/koch_leader.py index 410796d1..8f5ac457 100644 --- a/lerobot/common/teleoperators/koch_leader/koch_leader.py +++ b/lerobot/common/teleoperators/koch_leader/koch_leader.py @@ -69,12 +69,12 @@ class KochLeader(Teleoperator): def is_connected(self) -> bool: return self.arm.is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} already connected") self.arm.connect() - if not self.is_calibrated: + if not self.is_calibrated and calibrate: self.calibrate() self.configure() diff --git a/lerobot/common/teleoperators/so100_leader/so100_leader.py b/lerobot/common/teleoperators/so100_leader/so100_leader.py index 4ca982c1..a063edd1 100644 --- a/lerobot/common/teleoperators/so100_leader/so100_leader.py +++ b/lerobot/common/teleoperators/so100_leader/so100_leader.py @@ -66,12 +66,12 @@ class SO100Leader(Teleoperator): def is_connected(self) -> bool: return self.arm.is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} already connected") self.arm.connect() - if not self.is_calibrated: + if not self.is_calibrated and calibrate: self.calibrate() self.configure() diff --git a/lerobot/common/teleoperators/teleoperator.py b/lerobot/common/teleoperators/teleoperator.py index ee8fd5a1..c09f76ad 100644 --- a/lerobot/common/teleoperators/teleoperator.py +++ b/lerobot/common/teleoperators/teleoperator.py @@ -46,7 +46,7 @@ class Teleoperator(abc.ABC): pass @abc.abstractmethod - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: """Connects to the teleoperator.""" pass diff --git a/lerobot/common/teleoperators/widowx/widowx.py b/lerobot/common/teleoperators/widowx/widowx.py index 0cec46f8..4b09f6d0 100644 --- a/lerobot/common/teleoperators/widowx/widowx.py +++ b/lerobot/common/teleoperators/widowx/widowx.py @@ -69,12 +69,12 @@ class WidowX(Teleoperator): def is_connected(self) -> bool: return self.arm.is_connected - def connect(self): + def connect(self, calibrate: bool = True): if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} already connected") self.arm.connect() - if not self.is_calibrated: + if not self.is_calibrated and calibrate: self.calibrate() self.configure() diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index 96d2365f..40d8fbde 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -67,11 +67,13 @@ class MockRobot(Robot): def is_connected(self) -> bool: return self._is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} already connected") self._is_connected = True + if calibrate: + self.calibrate() @property def is_calibrated(self) -> bool: diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index d8038096..a7f5cad3 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -51,11 +51,13 @@ class MockTeleop(Teleoperator): def is_connected(self) -> bool: return self._is_connected - def connect(self) -> None: + def connect(self, calibrate: bool = True) -> None: if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} already connected") self._is_connected = True + if calibrate: + self.calibrate() @property def is_calibrated(self) -> bool: diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 06c4b6f8..2b12e742 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -1,5 +1,6 @@ import time +from lerobot.calibrate import CalibrateConfig, calibrate from lerobot.record import DatasetRecordConfig, RecordConfig, record from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay from lerobot.teleoperate import TeleoperateConfig, teleoperate @@ -8,6 +9,12 @@ from tests.mocks.mock_robot import MockRobotConfig from tests.mocks.mock_teleop import MockTeleopConfig +def test_calibrate(): + robot_cfg = MockRobotConfig() + cfg = CalibrateConfig(device=robot_cfg) + calibrate(cfg) + + def test_teleoperate(): robot_cfg = MockRobotConfig() teleop_cfg = MockTeleopConfig()