From 2c86fea78aa3a3177411a9758a0a703baac74ea8 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Tue, 8 Apr 2025 12:44:09 +0200 Subject: [PATCH 01/23] Switch typos pre-commit to mirror (#953) --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b5e09719c..4df93a36a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,8 +36,8 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/crate-ci/typos - rev: v1 + - repo: https://github.com/adhtruong/mirrors-typos + rev: v1.31.1 hooks: - id: typos args: [--force-exclude] From 034171a89abd18e9a262c04f4f40c0b0eb04a5bd Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Wed, 9 Apr 2025 10:26:30 +0200 Subject: [PATCH 02/23] Add Feetech protocol version --- lerobot/common/motors/feetech/feetech.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 064927b0c..35a8a715a 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -31,7 +31,7 @@ from .tables import ( SCAN_BAUDRATES, ) -PROTOCOL_VERSION = 0 +DEFAULT_PROTOCOL_VERSION = 0 BAUDRATE = 1_000_000 DEFAULT_TIMEOUT_MS = 1000 @@ -97,6 +97,7 @@ class FeetechMotorsBus(MotorsBus): port: str, motors: dict[str, Motor], calibration: dict[str, MotorCalibration] | None = None, + protocol_version: int = DEFAULT_PROTOCOL_VERSION, ): super().__init__(port, motors, calibration) import scservo_sdk as scs @@ -106,7 +107,7 @@ class FeetechMotorsBus(MotorsBus): self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( self.port_handler, scs.PortHandler ) - self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION) + self.packet_handler = scs.PacketHandler(protocol_version) self.sync_reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) self.sync_writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) self._comm_success = scs.COMM_SUCCESS From 4041f57943ae0383e0995825cb0e49e951fe6fc1 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 9 Apr 2025 17:33:01 +0200 Subject: [PATCH 03/23] feat(visualization): replace cv2 GUI with Rerun (and solves ffmpeg versioning issues) (#903) --- README.md | 8 +-- benchmarks/video/capture_camera_feed.py | 32 ++++++++---- examples/10_use_so100.md | 13 ++++- examples/11_use_lekiwi.md | 22 ++++++-- examples/11_use_moss.md | 13 ++++- examples/2_evaluate_pretrained_policy.py | 2 +- examples/7_get_started_with_real_robot.md | 13 +++-- examples/8_use_stretch.md | 13 +++-- examples/9_use_aloha.md | 12 ++++- lerobot/common/policies/pi0/modeling_pi0.py | 2 +- .../common/robot_devices/control_configs.py | 9 +++- lerobot/common/robot_devices/control_utils.py | 32 ++++++------ lerobot/scripts/control_robot.py | 52 +++++++++++++++++-- pyproject.toml | 4 +- tests/robots/test_control_robot.py | 13 +++-- 15 files changed, 175 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 4483940d4..20ebeee87 100644 --- a/README.md +++ b/README.md @@ -98,14 +98,14 @@ conda create -y -n lerobot python=3.10 conda activate lerobot ``` -When using `miniconda`, if you don't have `ffmpeg` in your environment: +When using `miniconda`, install `ffmpeg` in your environment: ```bash -conda install ffmpeg +conda install ffmpeg -c conda-forge ``` Install 🤗 LeRobot: ```bash -pip install --no-binary=av -e . +pip install -e . ``` > **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run: @@ -118,7 +118,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst For instance, to install 🤗 LeRobot with aloha and pusht, use: ```bash -pip install --no-binary=av -e ".[aloha, pusht]" +pip install -e ".[aloha, pusht]" ``` To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with diff --git a/benchmarks/video/capture_camera_feed.py b/benchmarks/video/capture_camera_feed.py index 3b4c356a8..ce248f20b 100644 --- a/benchmarks/video/capture_camera_feed.py +++ b/benchmarks/video/capture_camera_feed.py @@ -17,12 +17,21 @@ import argparse import datetime as dt +import os +import time from pathlib import Path import cv2 +import rerun as rr + +# see https://rerun.io/docs/howto/visualization/limit-ram +RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%") -def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int): +def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int): + rr.init("lerobot_capture_camera_feed") + rr.spawn(memory_limit=RERUN_MEMORY_LIMIT) + now = dt.datetime.now() capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}" if not capture_dir.exists(): @@ -39,24 +48,21 @@ def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height) frame_index = 0 - while True: + start_time = time.time() + while time.time() - start_time < duration: ret, frame = cap.read() if not ret: print("Error: Could not read frame.") break - - cv2.imshow("Video Stream", frame) + rr.log("video/stream", rr.Image(frame.numpy()), static=True) cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame) frame_index += 1 - # Break the loop on 'q' key press - if cv2.waitKey(1) & 0xFF == ord("q"): - break - - # Release the capture and destroy all windows + # Release the capture cap.release() - cv2.destroyAllWindows() + + # TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API. if __name__ == "__main__": @@ -86,5 +92,11 @@ if __name__ == "__main__": default=720, help="Height of the captured images.", ) + parser.add_argument( + "--duration", + type=int, + default=20, + help="Duration in seconds for which the video stream should be captured.", + ) args = parser.parse_args() display_and_save_video_stream(**vars(args)) diff --git a/examples/10_use_so100.md b/examples/10_use_so100.md index 8fb6d3b55..9dbe974c1 100644 --- a/examples/10_use_so100.md +++ b/examples/10_use_so100.md @@ -57,9 +57,15 @@ conda activate lerobot git clone https://github.com/huggingface/lerobot.git ~/lerobot ``` -#### 5. Install LeRobot with dependencies for the feetech motors: +#### 5. Install ffmpeg in your environment: +When using `miniconda`, install `ffmpeg` in your environment: ```bash -cd ~/lerobot && pip install --no-binary=av -e ".[feetech]" +conda install ffmpeg -c conda-forge +``` + +#### 6. Install LeRobot with dependencies for the feetech motors: +```bash +cd ~/lerobot && pip install -e ".[feetech]" ``` Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:. @@ -491,6 +497,9 @@ python lerobot/scripts/control_robot.py \ #### a. Teleop with displaying cameras Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset. + +> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. + ```bash python lerobot/scripts/control_robot.py \ --robot.type=so100 \ diff --git a/examples/11_use_lekiwi.md b/examples/11_use_lekiwi.md index 215419e19..1be7cbc4a 100644 --- a/examples/11_use_lekiwi.md +++ b/examples/11_use_lekiwi.md @@ -67,9 +67,15 @@ conda activate lerobot git clone https://github.com/huggingface/lerobot.git ~/lerobot ``` -#### 5. Install LeRobot with dependencies for the feetech motors: +#### 5. Install ffmpeg in your environment: +When using `miniconda`, install `ffmpeg` in your environment: ```bash -cd ~/lerobot && pip install --no-binary=av -e ".[feetech]" +conda install ffmpeg -c conda-forge +``` + +#### 6. Install LeRobot with dependencies for the feetech motors: +```bash +cd ~/lerobot && pip install -e ".[feetech]" ``` ## C. Install LeRobot on laptop @@ -108,9 +114,15 @@ conda activate lerobot git clone https://github.com/huggingface/lerobot.git ~/lerobot ``` -#### 5. Install LeRobot with dependencies for the feetech motors: +#### 5. Install ffmpeg in your environment: +When using `miniconda`, install `ffmpeg` in your environment: ```bash -cd ~/lerobot && pip install --no-binary=av -e ".[feetech]" +conda install ffmpeg -c conda-forge +``` + +#### 6. Install LeRobot with dependencies for the feetech motors: +```bash +cd ~/lerobot && pip install -e ".[feetech]" ``` Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:. @@ -412,6 +424,8 @@ python lerobot/scripts/control_robot.py \ --control.fps=30 ``` +> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. For the `--control.type=remote_robot` you will also need to set `--control.viewer_ip` and `--control.viewer_port` + You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below: | Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) | | ---------- | ------------------ | ---------------------- | diff --git a/examples/11_use_moss.md b/examples/11_use_moss.md index 7b1be232c..1b6f23b9a 100644 --- a/examples/11_use_moss.md +++ b/examples/11_use_moss.md @@ -31,9 +31,15 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot git clone https://github.com/huggingface/lerobot.git ~/lerobot ``` -5. Install LeRobot with dependencies for the feetech motors: +5. Install ffmpeg in your environment: +When using `miniconda`, install `ffmpeg` in your environment: ```bash -cd ~/lerobot && pip install --no-binary=av -e ".[feetech]" +conda install ffmpeg -c conda-forge +``` + +6. Install LeRobot with dependencies for the feetech motors: +```bash +cd ~/lerobot && pip install -e ".[feetech]" ``` ## Configure the motors @@ -212,6 +218,9 @@ python lerobot/scripts/control_robot.py \ **Teleop with displaying cameras** Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset. + +> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. + ```bash python lerobot/scripts/control_robot.py \ --robot.type=moss \ diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 24b5ea2c8..edbbad389 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -18,7 +18,7 @@ training outputs directory. In the latter case, you might want to run examples/3 It requires the installation of the 'gym_pusht' simulation environment. Install it by running: ```bash -pip install --no-binary=av -e ".[pusht]"` +pip install -e ".[pusht]"` ``` """ diff --git a/examples/7_get_started_with_real_robot.md b/examples/7_get_started_with_real_robot.md index 5b12e903f..3562c0e66 100644 --- a/examples/7_get_started_with_real_robot.md +++ b/examples/7_get_started_with_real_robot.md @@ -33,7 +33,7 @@ First, install the additional dependencies required for robots built with dynami Using `pip`: ```bash -pip install --no-binary=av -e ".[dynamixel]" +pip install -e ".[dynamixel]" ``` Using `poetry`: @@ -55,6 +55,9 @@ Finally, connect both arms to your computer via USB. Note that the USB doesn't p Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file. If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial): + +> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. + ```bash python lerobot/scripts/control_robot.py \ --robot.type=koch \ @@ -828,10 +831,10 @@ It contains: Troubleshooting: - On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can: - - install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`), - - or, install [Homebrew](https://brew.sh) and run `brew install ffmpeg` (it should be compiled with `libsvtav1`), - - or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), - - and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. + - install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`), +> **NOTE:** This usually installs `ffmpeg 7.X` for your platform (check the version installed with `ffmpeg -encoders | grep libsvtav1`). If it isn't `ffmpeg 7.X` or lacks `libsvtav1` support, you can explicitly install `ffmpeg 7.X` using: `conda install ffmpeg=7.1.1 -c conda-forge` + - or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), + - and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. - On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running: diff --git a/examples/8_use_stretch.md b/examples/8_use_stretch.md index d02e7ef39..a7a7dde17 100644 --- a/examples/8_use_stretch.md +++ b/examples/8_use_stretch.md @@ -43,14 +43,19 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot git clone https://github.com/huggingface/lerobot.git ~/lerobot ``` -6. Install LeRobot with stretch dependencies: +6. When using `miniconda`, install `ffmpeg` in your environment: ```bash -cd ~/lerobot && pip install --no-binary=av -e ".[stretch]" +conda install ffmpeg -c conda-forge +``` + +7. Install LeRobot with stretch dependencies: +```bash +cd ~/lerobot && pip install -e ".[stretch]" ``` > **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.` -7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready: +8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready: ```bash stretch_system_check.py ``` @@ -97,6 +102,8 @@ This is equivalent to running `stretch_robot_home.py` Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation). Now try out teleoperation (see above documentation to learn about the gamepad controls): + +> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. ```bash python lerobot/scripts/control_robot.py \ --robot.type=stretch \ diff --git a/examples/9_use_aloha.md b/examples/9_use_aloha.md index 1f7aee3c8..77cff1611 100644 --- a/examples/9_use_aloha.md +++ b/examples/9_use_aloha.md @@ -30,9 +30,14 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot git clone https://github.com/huggingface/lerobot.git ~/lerobot ``` -5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense): +5. When using `miniconda`, install `ffmpeg` in your environment: ```bash -cd ~/lerobot && pip install --no-binary=av -e ".[dynamixel, intelrealsense]" +conda install ffmpeg -c conda-forge +``` + +6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense): +```bash +cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]" ``` ## Teleoperate @@ -43,6 +48,9 @@ Teleoperation consists in manually operating the leader arms to move the followe 2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics. By running the following code, you can start your first **SAFE** teleoperation: + +> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. + ```bash python lerobot/scripts/control_robot.py \ --robot.type=aloha \ diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index 4462f162b..7599fa635 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -24,7 +24,7 @@ Designed by Physical Intelligence. Ported from Jax by Hugging Face. Install pi0 extra dependencies: ```bash -pip install --no-binary=av -e ".[pi0]" +pip install -e ".[pi0]" ``` Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`): diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index 0ecd8683a..cb558c716 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -41,7 +41,7 @@ class TeleoperateControlConfig(ControlConfig): fps: int | None = None teleop_time_s: float | None = None # Display all cameras on screen - display_cameras: bool = True + display_data: bool = False @ControlConfig.register_subclass("record") @@ -82,7 +82,7 @@ class RecordControlConfig(ControlConfig): # Not enough threads might cause low camera fps. num_image_writer_threads_per_camera: int = 4 # Display all cameras on screen - display_cameras: bool = True + display_data: bool = False # Use vocal synthesis to read events. play_sounds: bool = True # Resume recording on an existing dataset. @@ -116,6 +116,11 @@ class ReplayControlConfig(ControlConfig): @dataclass class RemoteRobotConfig(ControlConfig): log_interval: int = 100 + # Display all cameras on screen + display_data: bool = False + # Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp) + viewer_ip: str | None = None + viewer_port: str | None = None @dataclass diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 78a8c6a6d..4e42a9896 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -24,7 +24,7 @@ from contextlib import nullcontext from copy import copy from functools import cache -import cv2 +import rerun as rr import torch from deepdiff import DeepDiff from termcolor import colored @@ -174,13 +174,13 @@ def warmup_record( events, enable_teleoperation, warmup_time_s, - display_cameras, + display_data, fps, ): control_loop( robot=robot, control_time_s=warmup_time_s, - display_cameras=display_cameras, + display_data=display_data, events=events, fps=fps, teleoperate=enable_teleoperation, @@ -192,7 +192,7 @@ def record_episode( dataset, events, episode_time_s, - display_cameras, + display_data, policy, fps, single_task, @@ -200,7 +200,7 @@ def record_episode( control_loop( robot=robot, control_time_s=episode_time_s, - display_cameras=display_cameras, + display_data=display_data, dataset=dataset, events=events, policy=policy, @@ -215,7 +215,7 @@ def control_loop( robot, control_time_s=None, teleoperate=False, - display_cameras=False, + display_data=False, dataset: LeRobotDataset | None = None, events=None, policy: PreTrainedPolicy = None, @@ -264,11 +264,15 @@ def control_loop( frame = {**observation, **action, "task": single_task} dataset.add_frame(frame) - if display_cameras and not is_headless(): + # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon) + if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")): + for k, v in action.items(): + for i, vv in enumerate(v): + rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy())) + image_keys = [key for key in observation if "image" in key] for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) + rr.log(key, rr.Image(observation[key].numpy()), static=True) if fps is not None: dt_s = time.perf_counter() - start_loop_t @@ -297,15 +301,11 @@ def reset_environment(robot, events, reset_time_s, fps): ) -def stop_recording(robot, listener, display_cameras): +def stop_recording(robot, listener, display_data): robot.disconnect() - if not is_headless(): - if listener is not None: - listener.stop() - - if display_cameras: - cv2.destroyAllWindows() + if not is_headless() and listener is not None: + listener.stop() def sanity_check_dataset_name(repo_id, policy_cfg): diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 3c3c43f91..3daea98d3 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -135,15 +135,19 @@ python lerobot/scripts/control_robot.py \ """ import logging +import os import time from dataclasses import asdict from pprint import pformat +import rerun as rr + # from safetensors.torch import load_file, save_file from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.control_configs import ( CalibrateControlConfig, + ControlConfig, ControlPipelineConfig, RecordControlConfig, RemoteRobotConfig, @@ -153,6 +157,7 @@ from lerobot.common.robot_devices.control_configs import ( from lerobot.common.robot_devices.control_utils import ( control_loop, init_keyboard_listener, + is_headless, log_control_info, record_episode, reset_environment, @@ -232,7 +237,7 @@ def teleoperate(robot: Robot, cfg: TeleoperateControlConfig): control_time_s=cfg.teleop_time_s, fps=cfg.fps, teleoperate=True, - display_cameras=cfg.display_cameras, + display_data=cfg.display_data, ) @@ -280,7 +285,7 @@ def record( # 3. place the cameras windows on screen enable_teleoperation = policy is None log_say("Warmup record", cfg.play_sounds) - warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps) + warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps) if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() @@ -296,7 +301,7 @@ def record( dataset=dataset, events=events, episode_time_s=cfg.episode_time_s, - display_cameras=cfg.display_cameras, + display_data=cfg.display_data, policy=policy, fps=cfg.fps, single_task=cfg.single_task, @@ -326,7 +331,7 @@ def record( break log_say("Stop recording", cfg.play_sounds, blocking=True) - stop_recording(robot, listener, cfg.display_cameras) + stop_recording(robot, listener, cfg.display_data) if cfg.push_to_hub: dataset.push_to_hub(tags=cfg.tags, private=cfg.private) @@ -363,6 +368,40 @@ def replay( log_control_info(robot, dt_s, fps=cfg.fps) +def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None: + """Initializes the Rerun SDK for visualizing the control loop. + + Args: + control_config: Configuration determining data display and robot type. + session_name: Rerun session name. Defaults to "lerobot_control_loop". + + Raises: + ValueError: If viewer IP is missing for non-remote configurations with display enabled. + """ + if (control_config.display_data and not is_headless()) or ( + control_config.display_data and isinstance(control_config, RemoteRobotConfig) + ): + # Configure Rerun flush batch size default to 8KB if not set + batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") + os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size + + # Initialize Rerun based on configuration + rr.init(session_name) + if isinstance(control_config, RemoteRobotConfig): + viewer_ip = control_config.viewer_ip + viewer_port = control_config.viewer_port + if not viewer_ip or not viewer_port: + raise ValueError( + "Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data." + ) + logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}") + rr.connect_tcp(f"{viewer_ip}:{viewer_port}") + else: + # Get memory limit for rerun viewer parameters + memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%") + rr.spawn(memory_limit=memory_limit) + + @parser.wrap() def control_robot(cfg: ControlPipelineConfig): init_logging() @@ -370,17 +409,22 @@ def control_robot(cfg: ControlPipelineConfig): robot = make_robot_from_config(cfg.robot) + # TODO(Steven): Blueprint for fixed window size + if isinstance(cfg.control, CalibrateControlConfig): calibrate(robot, cfg.control) elif isinstance(cfg.control, TeleoperateControlConfig): + _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop") teleoperate(robot, cfg.control) elif isinstance(cfg.control, RecordControlConfig): + _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record") record(robot, cfg.control) elif isinstance(cfg.control, ReplayControlConfig): replay(robot, cfg.control) elif isinstance(cfg.control, RemoteRobotConfig): from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi + _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote") run_lekiwi(cfg.robot) if robot.is_connected: diff --git a/pyproject.toml b/pyproject.toml index 6b9b6802c..4b858634d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,9 +60,9 @@ dependencies = [ "jsonlines>=4.0.0", "numba>=0.59.0", "omegaconf>=2.3.0", - "opencv-python>=4.9.0", + "opencv-python-headless>=4.9.0", "packaging>=24.2", - "av>=12.0.5,<13.0.0", + "av>=12.0.5", "pymunk>=6.6.0", "pynput>=1.7.7", "pyzmq>=26.2.1", diff --git a/tests/robots/test_control_robot.py b/tests/robots/test_control_robot.py index 61d1caad7..3f618fc27 100644 --- a/tests/robots/test_control_robot.py +++ b/tests/robots/test_control_robot.py @@ -172,8 +172,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock): push_to_hub=False, # TODO(rcadene, aliberts): test video=True video=False, - # TODO(rcadene): display cameras through cv2 sometimes crashes on mac - display_cameras=False, + display_data=False, play_sounds=False, ) dataset = record(robot, rec_cfg) @@ -226,7 +225,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock): num_episodes=2, push_to_hub=False, video=False, - display_cameras=False, + display_data=False, play_sounds=False, num_image_writer_processes=num_image_writer_processes, ) @@ -273,7 +272,7 @@ def test_resume_record(tmp_path, request, robot_type, mock): episode_time_s=1, push_to_hub=False, video=False, - display_cameras=False, + display_data=False, play_sounds=False, num_episodes=1, ) @@ -330,7 +329,7 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock) num_episodes=1, push_to_hub=False, video=False, - display_cameras=False, + display_data=False, play_sounds=False, ) dataset = record(robot, rec_cfg) @@ -380,7 +379,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock): num_episodes=1, push_to_hub=False, video=False, - display_cameras=False, + display_data=False, play_sounds=False, ) @@ -433,7 +432,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n num_episodes=2, push_to_hub=False, video=False, - display_cameras=False, + display_data=False, play_sounds=False, num_image_writer_processes=num_image_writer_processes, ) From 5322417c0302b517b94d938e12b0e10405e6b649 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 9 Apr 2025 17:44:32 +0200 Subject: [PATCH 04/23] fix(examples): removes extra backtick (#948) --- examples/2_evaluate_pretrained_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index edbbad389..686069589 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -18,7 +18,7 @@ training outputs directory. In the latter case, you might want to run examples/3 It requires the installation of the 'gym_pusht' simulation environment. Install it by running: ```bash -pip install -e ".[pusht]"` +pip install -e ".[pusht]" ``` """ From 42a87e7211bd7c5664437e60f9e263166e32bc44 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 10 Apr 2025 00:35:14 +0200 Subject: [PATCH 05/23] Implement read --- lerobot/common/motors/dynamixel/dynamixel.py | 3 + lerobot/common/motors/feetech/feetech.py | 7 +++ lerobot/common/motors/motors_bus.py | 63 ++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index cb2d28294..dc16ba6ee 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -116,6 +116,9 @@ class DynamixelMotorsBus(MotorsBus): self._comm_success = dxl.COMM_SUCCESS self._no_error = 0x00 + def _assert_protocol_is_compatible(self, instruction_name: str) -> None: + pass + def configure_motors(self) -> None: # By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on # the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0). diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 35a8a715a..89c8ac9f7 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -102,6 +102,7 @@ class FeetechMotorsBus(MotorsBus): super().__init__(port, motors, calibration) import scservo_sdk as scs + self.protocol_version = protocol_version self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( @@ -113,6 +114,12 @@ class FeetechMotorsBus(MotorsBus): self._comm_success = scs.COMM_SUCCESS self._no_error = 0x00 + def _assert_protocol_is_compatible(self, instruction_name: str) -> None: + if instruction_name == "sync_read" and self.protocol_version == 1: + raise NotImplementedError( + "'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' instead." + ) + def configure_motors(self) -> None: # By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on the # 'Return_Delay' address). We ensure this is reduced to the minimum of 2µs (value of 0). diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 7b568a2f0..9f3fcdb2c 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -393,6 +393,10 @@ class MotorsBus(abc.ABC): "was found instead for that id." ) + @abc.abstractmethod + def _assert_protocol_is_compatible(self, instruction_name: str) -> None: + pass + @property def is_connected(self) -> bool: return self.port_handler.is_open @@ -723,6 +727,63 @@ class MotorsBus(abc.ABC): ) -> dict[int, list[int, str]] | None: pass + def read( + self, + data_name: str, + motor: str, + *, + normalize: bool = True, + num_retry: int = 0, + ) -> Value: + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." + ) + + id_ = self.motors[motor].id + model = self.motors[motor].model + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + + value, comm, error = self._read(addr, n_bytes, id_, num_retry=num_retry) + if not self._is_comm_success(comm): + raise ConnectionError( + f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." + f"{self.packet_handler.getTxRxResult(comm)}" + ) + elif self._is_error(error): + raise RuntimeError( + f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." + f"\n{self.packet_handler.getRxPacketError(error)}" + ) + + id_value = self._decode_sign(data_name, {id_: value}) + + if normalize and data_name in self.normalized_data: + id_value = self._normalize(data_name, id_value) + + return id_value[id_] + + def _read(self, addr: int, n_bytes: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]: + if n_bytes == 1: + read_fn = self.packet_handler.read1ByteTxRx + elif n_bytes == 2: + read_fn = self.packet_handler.read2ByteTxRx + elif n_bytes == 4: + read_fn = self.packet_handler.read4ByteTxRx + else: + raise ValueError(n_bytes) + + for n_try in range(1 + num_retry): + value, comm, error = read_fn(self.port_handler, motor_id, addr) + if self._is_comm_success(comm): + break + logger.debug( + f"Failed to read @{addr=} ({n_bytes=}) on {motor_id=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) + ) + + return value, comm, error + def sync_read( self, data_name: str, @@ -736,6 +797,8 @@ class MotorsBus(abc.ABC): f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." ) + self._assert_protocol_is_compatible("sync_read") + names = self._get_names_list(motors) ids = [self.motors[name].id for name in names] models = [self.motors[name].model for name in names] From 443fed216ca9961149e4cf5296222c22b25f5817 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 10 Apr 2025 00:49:03 +0200 Subject: [PATCH 06/23] Use constants from sdks --- tests/mocks/mock_dynamixel.py | 54 +++++++++++++++++------------------ tests/mocks/mock_feetech.py | 26 ++++++++--------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 0d100bb16..787380259 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -47,37 +47,37 @@ DXL_CRC_TABLE = [ # https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction INSTRUCTION_TYPES = { - "Ping": 0x01, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID - "Read": 0x02, # Read data from the Device - "Write": 0x03, # Write data to the Device - "Reg_Write": 0x04, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command - "Action": 0x05, # Executes a Packet that was registered beforehand using Reg Write - "Factory_Reset": 0x06, # Resets the Control Table to its initial factory default settings - "Reboot": 0x08, # Reboot the Device - "Clear": 0x10, # Reset certain information stored in memory - "Control_Table_Backup": 0x20, # Store current Control Table status data to a Backup or to restore backup EEPROM data. - "Status": 0x55, # Return packet sent following the execution of an Instruction Packet - "Sync_Read": 0x82, # Read data from multiple devices with the same Address with the same length at once - "Sync_Write": 0x83, # Write data to multiple devices with the same Address with the same length at once - "Fast_Sync_Read": 0x8A, # Read data from multiple devices with the same Address with the same length at once - "Bulk_Read": 0x92, # Read data from multiple devices with different Addresses with different lengths at once - "Bulk_Write": 0x93, # Write data to multiple devices with different Addresses with different lengths at once - "Fast_Bulk_Read": 0x9A, # Read data from multiple devices with different Addresses with different lengths at once + "Ping": dxl.INST_PING, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID + "Read": dxl.INST_READ, # Read data from the Device + "Write": dxl.INST_WRITE, # Write data to the Device + "Reg_Write": dxl.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command + "Action": dxl.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write + "Factory_Reset": dxl.INST_FACTORY_RESET, # Resets the Control Table to its initial factory default settings + "Reboot": dxl.INST_REBOOT, # Reboot the Device + "Clear": dxl.INST_CLEAR, # Reset certain information stored in memory + "Control_Table_Backup": 0x20, # Store current Control Table status data to a Backup or to restore backup EEPROM data. + "Status": dxl.INST_STATUS, # Return packet sent following the execution of an Instruction Packet + "Sync_Read": dxl.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once + "Sync_Write": dxl.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once + "Fast_Sync_Read": 0x8A, # Read data from multiple devices with the same Address with the same length at once + "Bulk_Read": dxl.INST_BULK_READ, # Read data from multiple devices with different Addresses with different lengths at once + "Bulk_Write": dxl.INST_BULK_WRITE, # Write data to multiple devices with different Addresses with different lengths at once + "Fast_Bulk_Read": 0x9A, # Read data from multiple devices with different Addresses with different lengths at once } # fmt: skip # https://emanual.robotis.com/docs/en/dxl/protocol2/#error ERROR_TYPE = { - "Success": 0x00, # No error - "Result_Fail": 0x01, # Failed to process the sent Instruction Packet - "Instruction_Error": 0x02, # An undefined Instruction has been usedAction has been used without Reg Write - "CRC_Error": 0x03, # The CRC of the sent Packet does not match the expected value - "Data_Range_Error": 0x04, # Data to be written to the specified Address is outside the range of the minimum/maximum value - "Data_Length_Error": 0x05, # Attempted to write Data that is shorter than the required data length of the specified Address - # (ex: when you attempt to only use 2 bytes of a register that has been defined as 4 bytes) - "Data_Limit_Error": 0x06, # Data to be written to the specified Address is outside of the configured Limit value - "Access_Error": 0x07, # Attempted to write a value to an Address that is Read Only or has not been defined - # Attempted to read a value from an Address that is Write Only or has not been defined - # Attempted to write a value to an EEPROM register while Torque was Enabled. + "Success": 0x00, # No error + "Result_Fail": dxl.ERRNUM_RESULT_FAIL, # Failed to process the sent Instruction Packet + "Instruction_Error": dxl.ERRNUM_INSTRUCTION, # An undefined Instruction has been usedAction has been used without Reg Write + "CRC_Error": dxl.ERRNUM_CRC, # The CRC of the sent Packet does not match the expected value + "Data_Range_Error": dxl.ERRNUM_DATA_RANGE, # Data to be written to the specified Address is outside the range of the minimum/maximum value + "Data_Length_Error": dxl.ERRNUM_DATA_LENGTH, # Attempted to write Data that is shorter than the required data length of the specified Address + # (ex: when you attempt to only use 2 bytes of a register that has been defined as 4 bytes) + "Data_Limit_Error": dxl.ERRNUM_DATA_LIMIT, # Data to be written to the specified Address is outside of the configured Limit value + "Access_Error": dxl.ERRNUM_ACCESS, # Attempted to write a value to an Address that is Read Only or has not been defined + # Attempted to read a value from an Address that is Write Only or has not been defined + # Attempted to write a value to an EEPROM register while Torque was Enabled. } # fmt: skip diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 56437b027..82be9f20f 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -12,23 +12,23 @@ from .mock_serial_patch import WaitableStub # https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf INSTRUCTION_TYPES = { - "Ping": 0x01, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID - "Read": 0x02, # Read data from the Device - "Write": 0x03, # Write data to the Device - "Reg_Write": 0x04, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command - "Action": 0x05, # Executes a Packet that was registered beforehand using Reg Write - "Factory_Reset": 0x06, # Resets the Control Table to its initial factory default settings - "Sync_Read": 0x82, # Read data from multiple devices with the same Address with the same length at once - "Sync_Write": 0x83, # Write data to multiple devices with the same Address with the same length at once + "Read": scs.INST_PING, # Read data from the Device + "Ping": scs.INST_READ, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID + "Write": scs.INST_WRITE, # Write data to the Device + "Reg_Write": scs.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command + "Action": scs.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write + "Factory_Reset": 0x06, # Resets the Control Table to its initial factory default settings + "Sync_Write": scs.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once + "Sync_Read": scs.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once } # fmt: skip ERROR_TYPE = { "Success": 0x00, - "Voltage": 0x01, - "Angle": 0x02, - "Overheat": 0x04, - "Overele": 0x08, - "Overload": 0x20, + "Voltage": scs.ERRBIT_VOLTAGE, + "Angle": scs.ERRBIT_ANGLE, + "Overheat": scs.ERRBIT_OVERHEAT, + "Overele": scs.ERRBIT_OVERELE, + "Overload": scs.ERRBIT_OVERLOAD, } From 4005065223ceea9fd4bae7eb62528d68df5f8a53 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 10 Apr 2025 00:51:23 +0200 Subject: [PATCH 07/23] (nit) move write --- lerobot/common/motors/motors_bus.py | 88 ++++++++++++++--------------- tests/mocks/mock_dynamixel.py | 68 +++++++++++----------- tests/mocks/mock_feetech.py | 58 +++++++++---------- 3 files changed, 107 insertions(+), 107 deletions(-) diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 9f3fcdb2c..16aa0402c 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -784,6 +784,50 @@ class MotorsBus(abc.ABC): return value, comm, error + def write( + self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 + ) -> None: + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." + ) + + id_ = self.motors[motor].id + model = self.motors[motor].model + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + + if normalize and data_name in self.normalized_data: + value = self._unnormalize(data_name, {id_: value})[id_] + + value = self._encode_sign(data_name, {id_: value})[id_] + + comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry) + if not self._is_comm_success(comm): + raise ConnectionError( + f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." + f"\n{self.packet_handler.getTxRxResult(comm)}" + ) + elif self._is_error(error): + raise RuntimeError( + f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." + f"\n{self.packet_handler.getRxPacketError(error)}" + ) + + def _write( + self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0 + ) -> tuple[int, int]: + data = self._split_int_to_bytes(value, n_bytes) + for n_try in range(1 + num_retry): + comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) + if self._is_comm_success(comm): + break + logger.debug( + f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) + ) + + return comm, error + def sync_read( self, data_name: str, @@ -914,50 +958,6 @@ class MotorsBus(abc.ABC): data = self._split_int_to_bytes(value, n_bytes) self.sync_writer.addParam(id_, data) - def write( - self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 - ) -> None: - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." - ) - - id_ = self.motors[motor].id - model = self.motors[motor].model - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) - - if normalize and data_name in self.normalized_data: - value = self._unnormalize(data_name, {id_: value})[id_] - - value = self._encode_sign(data_name, {id_: value})[id_] - - comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - f"\n{self.packet_handler.getTxRxResult(comm)}" - ) - elif self._is_error(error): - raise RuntimeError( - f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - f"\n{self.packet_handler.getRxPacketError(error)}" - ) - - def _write( - self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0 - ) -> tuple[int, int]: - data = self._split_int_to_bytes(value, n_bytes) - for n_try in range(1 + num_retry): - comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) - if self._is_comm_success(comm): - break - logger.debug( - f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): " - + self.packet_handler.getTxRxResult(comm) - ) - - return comm, error - def disconnect(self, disable_torque: bool = True) -> None: if not self.is_connected: raise DeviceNotConnectedError( diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 787380259..454d8da80 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -212,6 +212,40 @@ class MockInstructionPacket(MockDynamixelPacketv2): params, length = [], 3 return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Ping") + @classmethod + def write( + cls, + dxl_id: int, + value: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Write" instruction. + https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03 + + The parameters for Write (Protocol 2.0) are: + param[0] = start_address L + param[1] = start_address H + param[2] = 1st Byte + param[3] = 2nd Byte + ... + param[1+X] = X-th Byte + + And 'length' = data_length + 5, where: + +1 is for instruction byte, + +2 is for the length bytes, + +2 is for the CRC at the end. + """ + data = DynamixelMotorsBus._split_int_to_bytes(value, data_length) + params = [ + dxl.DXL_LOBYTE(start_address), + dxl.DXL_HIBYTE(start_address), + *data, + ] + length = data_length + 5 + return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Write") + @classmethod def sync_read( cls, @@ -293,40 +327,6 @@ class MockInstructionPacket(MockDynamixelPacketv2): length = len(ids_values) * (1 + data_length) + 7 return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") - @classmethod - def write( - cls, - dxl_id: int, - value: int, - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Write" instruction. - https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03 - - The parameters for Write (Protocol 2.0) are: - param[0] = start_address L - param[1] = start_address H - param[2] = 1st Byte - param[3] = 2nd Byte - ... - param[1+X] = X-th Byte - - And 'length' = data_length + 5, where: - +1 is for instruction byte, - +2 is for the length bytes, - +2 is for the CRC at the end. - """ - data = DynamixelMotorsBus._split_int_to_bytes(value, data_length) - params = [ - dxl.DXL_LOBYTE(start_address), - dxl.DXL_HIBYTE(start_address), - *data, - ] - length = data_length + 5 - return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Write") - class MockStatusPacket(MockDynamixelPacketv2): """ diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 82be9f20f..dfddaa1f7 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -115,6 +115,35 @@ class MockInstructionPacket(MockFeetechPacket): length = 4 return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Read") + @classmethod + def write( + cls, + scs_id: int, + value: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Write" instruction. + + The parameters for Write are: + param[0] = start_address L + param[1] = start_address H + param[2] = 1st Byte + param[3] = 2nd Byte + ... + param[1+X] = X-th Byte + + And 'length' = data_length + 3, where: + +1 is for instruction byte, + +1 is for the length bytes, + +1 is for the checksum at the end. + """ + data = FeetechMotorsBus._split_int_to_bytes(value, data_length) + params = [start_address, *data] + length = data_length + 3 + return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") + @classmethod def sync_read( cls, @@ -178,35 +207,6 @@ class MockInstructionPacket(MockFeetechPacket): length = len(ids_values) * (1 + data_length) + 4 return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") - @classmethod - def write( - cls, - scs_id: int, - value: int, - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Write" instruction. - - The parameters for Write are: - param[0] = start_address L - param[1] = start_address H - param[2] = 1st Byte - param[3] = 2nd Byte - ... - param[1+X] = X-th Byte - - And 'length' = data_length + 3, where: - +1 is for instruction byte, - +1 is for the length bytes, - +1 is for the checksum at the end. - """ - data = FeetechMotorsBus._split_int_to_bytes(value, data_length) - params = [start_address, *data] - length = data_length + 3 - return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") - class MockStatusPacket(MockFeetechPacket): """ From 12abc9ca864ccabed34a96720be978077d00d9d2 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 10 Apr 2025 00:53:17 +0200 Subject: [PATCH 08/23] Fix broadcast ping type hint --- lerobot/common/motors/motors_bus.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 16aa0402c..4c2a836c3 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -722,9 +722,7 @@ class MotorsBus(abc.ABC): return model_number @abc.abstractmethod - def broadcast_ping( - self, num_retry: int = 0, raise_on_error: bool = False - ) -> dict[int, list[int, str]] | None: + def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: pass def read( From 27cb0c40bdf01b48ee11113002565c3c98e7876e Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 10 Apr 2025 17:14:40 +0200 Subject: [PATCH 09/23] Add protocol 1 broadcast ping --- lerobot/common/motors/dynamixel/dynamixel.py | 27 ++--- lerobot/common/motors/feetech/feetech.py | 105 +++++++++++-------- lerobot/common/motors/feetech/tables.py | 2 +- lerobot/common/motors/motors_bus.py | 60 +++++------ 4 files changed, 99 insertions(+), 95 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index dc16ba6ee..8f69b8b5f 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -125,13 +125,13 @@ class DynamixelMotorsBus(MotorsBus): for id_ in self.ids: self.write("Return_Delay_Time", id_, 0) - def _disable_torque(self, motors: list[NameOrID]) -> None: - for motor in motors: - self.write("Torque_Enable", motor, TorqueMode.DISABLED.value) + def disable_torque(self, motors: str | list[str] | None = None) -> None: + for name in self._get_names_list(motors): + self.write("Torque_Enable", name, TorqueMode.DISABLED.value) - def _enable_torque(self, motors: list[NameOrID]) -> None: - for motor in motors: - self.write("Torque_Enable", motor, TorqueMode.ENABLED.value) + def enable_torque(self, motors: str | list[str] | None = None) -> None: + for name in self._get_names_list(motors): + self.write("Torque_Enable", name, TorqueMode.ENABLED.value) def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: for id_ in ids_values: @@ -167,22 +167,9 @@ class DynamixelMotorsBus(MotorsBus): return half_turn_homings @staticmethod - def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]: - # Validate input - if value < 0: - raise ValueError(f"Negative values are not allowed: {value}") - - max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(n_bytes) - if max_value is None: - raise NotImplementedError(f"Unsupported byte size: {n_bytes}. Expected [1, 2, 4].") - - if value > max_value: - raise ValueError(f"Value {value} exceeds the maximum for {n_bytes} bytes ({max_value}).") - + def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: import dynamixel_sdk as dxl - # Note: No need to convert back into unsigned int, since this byte preprocessing - # already handles it for us. if n_bytes == 1: data = [value] elif n_bytes == 2: diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 89c8ac9f7..f7557c972 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -139,15 +139,15 @@ class FeetechMotorsBus(MotorsBus): return half_turn_homings - def _disable_torque(self, motors: list[NameOrID]) -> None: - for motor in motors: - self.write("Torque_Enable", motor, TorqueMode.DISABLED.value) - self.write("Lock", motor, 0) + def disable_torque(self, motors: str | list[str] | None = None) -> None: + for name in self._get_names_list(motors): + self.write("Torque_Enable", name, TorqueMode.DISABLED.value) + self.write("Lock", name, 0) - def _enable_torque(self, motors: list[NameOrID]) -> None: - for motor in motors: - self.write("Torque_Enable", motor, TorqueMode.ENABLED.value) - self.write("Lock", motor, 1) + def enable_torque(self, motors: str | list[str] | None = None) -> None: + for name in self._get_names_list(motors): + self.write("Torque_Enable", name, TorqueMode.ENABLED.value) + self.write("Lock", name, 1) def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: for id_ in ids_values: @@ -170,18 +170,7 @@ class FeetechMotorsBus(MotorsBus): return ids_values @staticmethod - def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]: - # Validate input - if value < 0: - raise ValueError(f"Negative values are not allowed: {value}") - - max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(n_bytes) - if max_value is None: - raise NotImplementedError(f"Unsupported byte size: {n_bytes}. Expected [1, 2, 4].") - - if value > max_value: - raise ValueError(f"Value {value} exceeds the maximum for {n_bytes} bytes ({max_value}).") - + def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: import scservo_sdk as scs if n_bytes == 1: @@ -197,7 +186,23 @@ class FeetechMotorsBus(MotorsBus): ] return data - def _broadcast_ping(self) -> tuple[dict[int, int], int]: + def _broadcast_ping_p1(self, known_motors_only: bool = True, num_retry: int = 0) -> dict[int, int]: + if known_motors_only: + ids = self.ids + else: + import scservo_sdk as scs + + ids = range(scs.MAX_ID + 1) + + ids_models = {} + for id_ in ids: + model_number = self.ping(id_, num_retry) + if model_number is not None: + ids_models[id_] = model_number + + return ids_models + + def _broadcast_ping_p0(self) -> tuple[dict[int, int], int]: import scservo_sdk as scs data_list = {} @@ -251,7 +256,7 @@ class FeetechMotorsBus(MotorsBus): for idx in range(2, status_length - 1): # except header & checksum checksum += rxpacket[idx] - checksum = scs.SCS_LOBYTE(~checksum) + checksum = ~checksum & 0xFF if rxpacket[status_length - 1] == checksum: result = scs.COMM_SUCCESS data_list[rxpacket[scs.PKT_ID]] = rxpacket[scs.PKT_ERROR] @@ -272,24 +277,31 @@ class FeetechMotorsBus(MotorsBus): rx_length = rx_length - idx def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: - for n_try in range(1 + num_retry): - ids_status, comm = self._broadcast_ping() - if self._is_comm_success(comm): - break - logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") - logger.debug(self.packet_handler.getTxRxResult(comm)) + if self.protocol_version == 0: + for n_try in range(1 + num_retry): + ids_status, comm = self._broadcast_ping_p0() + if self._is_comm_success(comm): + break + logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") + logger.debug(self.packet_handler.getTxRxResult(comm)) - if not self._is_comm_success(comm): - if raise_on_error: - raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + if not self._is_comm_success(comm): + if raise_on_error: + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + return - ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} - if ids_errors: - display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} - logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") + ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} + if ids_errors: + display_dict = { + id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items() + } + logger.error( + f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}" + ) - return self._get_model_number(list(ids_status), raise_on_error) + return self._get_model_number(list(ids_status), raise_on_error) + else: + return self._broadcast_ping_p1(num_retry=num_retry) def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: # comm, major = self._sync_read(*FIRMWARE_MAJOR_VERSION, motor_ids) @@ -328,11 +340,20 @@ class FeetechMotorsBus(MotorsBus): # return # return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids} + if self.protocol_version == 1: + model_numbers = {} + for id_ in motor_ids: + model_nb, comm, error = self._read(*MODEL_NUMBER, id_) + if self._is_comm_success(comm) and not self._is_error(error): + model_numbers[id_] = model_nb + elif raise_on_error: + raise Exception # FIX - comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids) - if not self._is_comm_success(comm): - if raise_on_error: - raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + else: + comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids) + if not self._is_comm_success(comm): + if raise_on_error: + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + return return model_numbers diff --git a/lerobot/common/motors/feetech/tables.py b/lerobot/common/motors/feetech/tables.py index 0fa2fa84f..e6d08cf82 100644 --- a/lerobot/common/motors/feetech/tables.py +++ b/lerobot/common/motors/feetech/tables.py @@ -199,5 +199,5 @@ MODEL_NUMBER_TABLE = { "sts3215": 777, "sts3250": None, "sm8512bl": None, - "scs0009": None, + "scs0009": 1284, } diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 4c2a836c3..3c64be7b6 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -445,34 +445,12 @@ class MotorsBus(abc.ABC): def configure_motors(self) -> None: pass - def disable_torque(self, motors: NameOrID | list[NameOrID] | None = None) -> None: - pass - if motors is None: - motors = self.names - elif isinstance(motors, (str, int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) - - self._disable_torque(motors) - - def enable_torque(self, motors: NameOrID | list[NameOrID] | None = None) -> None: - pass - if motors is None: - motors = self.names - elif isinstance(motors, (str, int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) - - self._enable_torque(motors) - @abc.abstractmethod - def _enable_torque(self, motors: list[NameOrID]) -> None: + def disable_torque(self, motors: str | list[str] | None = None) -> None: pass @abc.abstractmethod - def _disable_torque(self, motors: list[NameOrID]) -> None: + def enable_torque(self, motors: str | list[str] | None = None) -> None: pass def set_timeout(self, timeout_ms: int | None = None): @@ -620,6 +598,8 @@ class MotorsBus(abc.ABC): return mins, maxes def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]: + if not self.calibration: + raise RuntimeError(f"{self} has no calibration registered.") normalized_values = {} for id_, val in ids_values.items(): name = self._id_to_name(id_) @@ -662,11 +642,10 @@ class MotorsBus(abc.ABC): def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: pass - @staticmethod - @abc.abstractmethod - def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]: + def _serialize_data(self, value: int, n_bytes: int) -> list[int]: """ - Splits an unsigned integer into a list of bytes in little-endian order. + Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication + protocol. Depending on the protocol, split values can be in big-endian or little-endian order. This function extracts the individual bytes of an integer based on the specified number of bytes (`n_bytes`). The output is a list of integers, @@ -678,7 +657,8 @@ class MotorsBus(abc.ABC): Args: value (int): The unsigned integer to be converted into a byte list. Must be within the valid range for the specified `n_bytes`. - n_bytes (int): The number of bytes to use for conversion. Supported values: + n_bytes (int): The number of bytes to use for conversion. Supported values for both Feetech and + Dynamixel: - 1 (for values 0 to 255) - 2 (for values 0 to 65,535) - 4 (for values 0 to 4,294,967,295) @@ -690,7 +670,7 @@ class MotorsBus(abc.ABC): Returns: list[int]: A list of integers, each representing a byte in **little-endian order**. - Examples: + Examples (for a little-endian protocol): >>> split_int_bytes(0x12, 1) [18] >>> split_int_bytes(0x1234, 2) @@ -698,6 +678,22 @@ class MotorsBus(abc.ABC): >>> split_int_bytes(0x12345678, 4) [120, 86, 52, 18] # 0x12345678 → 0x78 0x56 0x34 0x12 """ + if value < 0: + raise ValueError(f"Negative values are not allowed: {value}") + + max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(n_bytes) + if max_value is None: + raise NotImplementedError(f"Unsupported byte size: {n_bytes}. Expected [1, 2, 4].") + + if value > max_value: + raise ValueError(f"Value {value} exceeds the maximum for {n_bytes} bytes ({max_value}).") + + return self._split_into_byte_chunks(value, n_bytes) + + @staticmethod + @abc.abstractmethod + def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: + """Convert an integer into a list of byte-sized integers.""" pass def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None: @@ -814,7 +810,7 @@ class MotorsBus(abc.ABC): def _write( self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0 ) -> tuple[int, int]: - data = self._split_int_to_bytes(value, n_bytes) + data = self._serialize_data(value, n_bytes) for n_try in range(1 + num_retry): comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) if self._is_comm_success(comm): @@ -953,7 +949,7 @@ class MotorsBus(abc.ABC): self.sync_writer.start_address = addr self.sync_writer.data_length = n_bytes for id_, value in ids_values.items(): - data = self._split_int_to_bytes(value, n_bytes) + data = self._serialize_data(value, n_bytes) self.sync_writer.addParam(id_, data) def disconnect(self, disable_torque: bool = True) -> None: From d32daebf75b84a69ff9b4a0a7e3b9582ceb20257 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Apr 2025 11:01:12 +0200 Subject: [PATCH 10/23] Refactor & add _serialize_data --- lerobot/common/motors/dynamixel/dynamixel.py | 8 +- lerobot/common/motors/feetech/feetech.py | 8 +- lerobot/common/motors/motors_bus.py | 116 +++++++------------ tests/mocks/mock_dynamixel.py | 6 +- tests/mocks/mock_feetech.py | 8 +- tests/motors/test_dynamixel.py | 18 +-- tests/motors/test_feetech.py | 18 +-- 7 files changed, 78 insertions(+), 104 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index 8f69b8b5f..a710afdec 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -167,14 +167,14 @@ class DynamixelMotorsBus(MotorsBus): return half_turn_homings @staticmethod - def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: + def _split_into_byte_chunks(value: int, length: int) -> list[int]: import dynamixel_sdk as dxl - if n_bytes == 1: + if length == 1: data = [value] - elif n_bytes == 2: + elif length == 2: data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] - elif n_bytes == 4: + elif length == 4: data = [ dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index f7557c972..a0796f9c6 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -170,14 +170,14 @@ class FeetechMotorsBus(MotorsBus): return ids_values @staticmethod - def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: + def _split_into_byte_chunks(value: int, length: int) -> list[int]: import scservo_sdk as scs - if n_bytes == 1: + if length == 1: data = [value] - elif n_bytes == 2: + elif length == 2: data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] - elif n_bytes == 4: + elif length == 4: data = [ scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 3c64be7b6..7bc8a4ae0 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -642,57 +642,31 @@ class MotorsBus(abc.ABC): def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: pass - def _serialize_data(self, value: int, n_bytes: int) -> list[int]: + def _serialize_data(self, value: int, length: int) -> list[int]: """ Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication protocol. Depending on the protocol, split values can be in big-endian or little-endian order. - This function extracts the individual bytes of an integer based on the - specified number of bytes (`n_bytes`). The output is a list of integers, - each representing a byte (0-255). - - **Byte order:** The function returns bytes in **little-endian format**, - meaning the least significant byte (LSB) comes first. - - Args: - value (int): The unsigned integer to be converted into a byte list. Must be within - the valid range for the specified `n_bytes`. - n_bytes (int): The number of bytes to use for conversion. Supported values for both Feetech and - Dynamixel: - - 1 (for values 0 to 255) - - 2 (for values 0 to 65,535) - - 4 (for values 0 to 4,294,967,295) - - Raises: - ValueError: If `value` is negative or exceeds the maximum allowed for `n_bytes`. - NotImplementedError: If `n_bytes` is not 1, 2, or 4. - - Returns: - list[int]: A list of integers, each representing a byte in **little-endian order**. - - Examples (for a little-endian protocol): - >>> split_int_bytes(0x12, 1) - [18] - >>> split_int_bytes(0x1234, 2) - [52, 18] # 0x1234 → 0x34 0x12 (little-endian) - >>> split_int_bytes(0x12345678, 4) - [120, 86, 52, 18] # 0x12345678 → 0x78 0x56 0x34 0x12 + Supported data length for both Feetech and Dynamixel: + - 1 (for values 0 to 255) + - 2 (for values 0 to 65,535) + - 4 (for values 0 to 4,294,967,295) """ if value < 0: raise ValueError(f"Negative values are not allowed: {value}") - max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(n_bytes) + max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(length) if max_value is None: - raise NotImplementedError(f"Unsupported byte size: {n_bytes}. Expected [1, 2, 4].") + raise NotImplementedError(f"Unsupported byte size: {length}. Expected [1, 2, 4].") if value > max_value: - raise ValueError(f"Value {value} exceeds the maximum for {n_bytes} bytes ({max_value}).") + raise ValueError(f"Value {value} exceeds the maximum for {length} bytes ({max_value}).") - return self._split_into_byte_chunks(value, n_bytes) + return self._split_into_byte_chunks(value, length) @staticmethod @abc.abstractmethod - def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: + def _split_into_byte_chunks(value: int, length: int) -> list[int]: """Convert an integer into a list of byte-sized integers.""" pass @@ -736,9 +710,9 @@ class MotorsBus(abc.ABC): id_ = self.motors[motor].id model = self.motors[motor].model - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + addr, length = get_address(self.model_ctrl_table, model, data_name) - value, comm, error = self._read(addr, n_bytes, id_, num_retry=num_retry) + value, comm, error = self._read(addr, length, id_, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." @@ -757,22 +731,22 @@ class MotorsBus(abc.ABC): return id_value[id_] - def _read(self, addr: int, n_bytes: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]: - if n_bytes == 1: + def _read(self, address: int, length: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]: + if length == 1: read_fn = self.packet_handler.read1ByteTxRx - elif n_bytes == 2: + elif length == 2: read_fn = self.packet_handler.read2ByteTxRx - elif n_bytes == 4: + elif length == 4: read_fn = self.packet_handler.read4ByteTxRx else: - raise ValueError(n_bytes) + raise ValueError(length) for n_try in range(1 + num_retry): - value, comm, error = read_fn(self.port_handler, motor_id, addr) + value, comm, error = read_fn(self.port_handler, motor_id, address) if self._is_comm_success(comm): break logger.debug( - f"Failed to read @{addr=} ({n_bytes=}) on {motor_id=} ({n_try=}): " + f"Failed to read @{address=} ({length=}) on {motor_id=} ({n_try=}): " + self.packet_handler.getTxRxResult(comm) ) @@ -788,14 +762,14 @@ class MotorsBus(abc.ABC): id_ = self.motors[motor].id model = self.motors[motor].model - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + addr, length = get_address(self.model_ctrl_table, model, data_name) if normalize and data_name in self.normalized_data: value = self._unnormalize(data_name, {id_: value})[id_] value = self._encode_sign(data_name, {id_: value})[id_] - comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry) + comm, error = self._write(addr, length, id_, value, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." @@ -808,15 +782,15 @@ class MotorsBus(abc.ABC): ) def _write( - self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0 + self, addr: int, length: int, motor_id: int, value: int, num_retry: int = 0 ) -> tuple[int, int]: - data = self._serialize_data(value, n_bytes) + data = self._serialize_data(value, length) for n_try in range(1 + num_retry): - comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) + comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, length, data) if self._is_comm_success(comm): break logger.debug( - f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): " + f"Failed to sync write @{addr=} ({length=}) on id={motor_id} with {value=} ({n_try=}): " + self.packet_handler.getTxRxResult(comm) ) @@ -845,9 +819,9 @@ class MotorsBus(abc.ABC): assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + addr, length = get_address(self.model_ctrl_table, model, data_name) - comm, ids_values = self._sync_read(addr, n_bytes, ids, num_retry=num_retry) + comm, ids_values = self._sync_read(addr, length, ids, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." @@ -862,25 +836,25 @@ class MotorsBus(abc.ABC): return {self._id_to_name(id_): value for id_, value in ids_values.items()} def _sync_read( - self, addr: int, n_bytes: int, motor_ids: list[int], num_retry: int = 0 + self, addr: int, length: int, motor_ids: list[int], num_retry: int = 0 ) -> tuple[int, dict[int, int]]: - self._setup_sync_reader(motor_ids, addr, n_bytes) + self._setup_sync_reader(motor_ids, addr, length) for n_try in range(1 + num_retry): comm = self.sync_reader.txRxPacket() if self._is_comm_success(comm): break logger.debug( - f"Failed to sync read @{addr=} ({n_bytes=}) on {motor_ids=} ({n_try=}): " + f"Failed to sync read @{addr=} ({length=}) on {motor_ids=} ({n_try=}): " + self.packet_handler.getTxRxResult(comm) ) - values = {id_: self.sync_reader.getData(id_, addr, n_bytes) for id_ in motor_ids} + values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids} return comm, values - def _setup_sync_reader(self, motor_ids: list[int], addr: int, n_bytes: int) -> None: + def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None: self.sync_reader.clearParam() self.sync_reader.start_address = addr - self.sync_reader.data_length = n_bytes + self.sync_reader.data_length = length for id_ in motor_ids: self.sync_reader.addParam(id_) @@ -888,15 +862,15 @@ class MotorsBus(abc.ABC): # Would have to handle the logic of checking if a packet has been sent previously though but doable. # This could be at the cost of increase latency between the moment the data is produced by the motors and # the moment it is used by a policy. - # def _async_read(self, motor_ids: list[int], address: int, n_bytes: int): - # if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...: - # self._setup_sync_reader(motor_ids, address, n_bytes) + # def _async_read(self, motor_ids: list[int], address: int, length: int): + # if self.sync_reader.start_address != address or self.sync_reader.data_length != length or ...: + # self._setup_sync_reader(motor_ids, address, length) # else: # self.sync_reader.rxPacket() # self.sync_reader.txPacket() # for id_ in motor_ids: - # value = self.sync_reader.getData(id_, address, n_bytes) + # value = self.sync_reader.getData(id_, address, length) def sync_write( self, @@ -917,39 +891,39 @@ class MotorsBus(abc.ABC): assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + addr, length = get_address(self.model_ctrl_table, model, data_name) if normalize and data_name in self.normalized_data: ids_values = self._unnormalize(data_name, ids_values) ids_values = self._encode_sign(data_name, ids_values) - comm = self._sync_write(addr, n_bytes, ids_values, num_retry=num_retry) + comm = self._sync_write(addr, length, ids_values, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." f"\n{self.packet_handler.getTxRxResult(comm)}" ) - def _sync_write(self, addr: int, n_bytes: int, ids_values: dict[int, int], num_retry: int = 0) -> int: - self._setup_sync_writer(ids_values, addr, n_bytes) + def _sync_write(self, addr: int, length: int, ids_values: dict[int, int], num_retry: int = 0) -> int: + self._setup_sync_writer(ids_values, addr, length) for n_try in range(1 + num_retry): comm = self.sync_writer.txPacket() if self._is_comm_success(comm): break logger.debug( - f"Failed to sync write @{addr=} ({n_bytes=}) with {ids_values=} ({n_try=}): " + f"Failed to sync write @{addr=} ({length=}) with {ids_values=} ({n_try=}): " + self.packet_handler.getTxRxResult(comm) ) return comm - def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, n_bytes: int) -> None: + def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None: self.sync_writer.clearParam() self.sync_writer.start_address = addr - self.sync_writer.data_length = n_bytes + self.sync_writer.data_length = length for id_, value in ids_values.items(): - data = self._serialize_data(value, n_bytes) + data = self._serialize_data(value, length) self.sync_writer.addParam(id_, data) def disconnect(self, disable_torque: bool = True) -> None: diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 454d8da80..feae051bb 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -237,7 +237,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): +2 is for the length bytes, +2 is for the CRC at the end. """ - data = DynamixelMotorsBus._split_int_to_bytes(value, data_length) + data = DynamixelMotorsBus._split_into_byte_chunks(value, data_length) params = [ dxl.DXL_LOBYTE(start_address), dxl.DXL_HIBYTE(start_address), @@ -315,7 +315,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): """ data = [] for id_, value in ids_values.items(): - split_value = DynamixelMotorsBus._split_int_to_bytes(value, data_length) + split_value = DynamixelMotorsBus._split_into_byte_chunks(value, data_length) data += [id_, *split_value] params = [ dxl.DXL_LOBYTE(start_address), @@ -389,7 +389,7 @@ class MockStatusPacket(MockDynamixelPacketv2): Returns: bytes: The raw 'Present_Position' status packet ready to be sent through serial. """ - params = DynamixelMotorsBus._split_int_to_bytes(value, param_length) + params = DynamixelMotorsBus._split_into_byte_chunks(value, param_length) length = param_length + 4 return cls.build(dxl_id, params=params, length=length) diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index dfddaa1f7..57bd8cbc7 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -49,7 +49,7 @@ class MockFeetechPacket(abc.ABC): for id_ in range(2, len(packet) - 1): # except header & checksum checksum += packet[id_] - packet[-1] = scs.SCS_LOBYTE(~checksum) + packet[-1] = ~checksum & 0xFF return packet @@ -139,7 +139,7 @@ class MockInstructionPacket(MockFeetechPacket): +1 is for the length bytes, +1 is for the checksum at the end. """ - data = FeetechMotorsBus._split_int_to_bytes(value, data_length) + data = FeetechMotorsBus._split_into_byte_chunks(value, data_length) params = [start_address, *data] length = data_length + 3 return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") @@ -201,7 +201,7 @@ class MockInstructionPacket(MockFeetechPacket): """ data = [] for id_, value in ids_values.items(): - split_value = FeetechMotorsBus._split_int_to_bytes(value, data_length) + split_value = FeetechMotorsBus._split_into_byte_chunks(value, data_length) data += [id_, *split_value] params = [start_address, data_length, *data] length = len(ids_values) * (1 + data_length) + 4 @@ -258,7 +258,7 @@ class MockStatusPacket(MockFeetechPacket): Returns: bytes: The raw 'Sync Read' status packet ready to be sent through serial. """ - params = FeetechMotorsBus._split_int_to_bytes(value, param_length) + params = FeetechMotorsBus._split_into_byte_chunks(value, param_length) length = param_length + 2 return cls.build(scs_id, params=params, length=length) diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index 6fd0e3a7c..e047e7c1c 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -62,7 +62,7 @@ def test_autouse_patch(): @pytest.mark.parametrize( - "value, n_bytes, expected", + "value, length, expected", [ (0x12, 1, [0x12]), (0x1234, 2, [0x34, 0x12]), @@ -86,24 +86,24 @@ def test_autouse_patch(): "max four bytes", ], ) # fmt: skip -def test_split_int_to_bytes(value, n_bytes, expected): - assert DynamixelMotorsBus._split_int_to_bytes(value, n_bytes) == expected +def test_serialize_data(value, length, expected): + assert DynamixelMotorsBus._serialize_data(value, length) == expected -def test_split_int_to_bytes_invalid_n_bytes(): +def test_serialize_data_invalid_length(): with pytest.raises(NotImplementedError): - DynamixelMotorsBus._split_int_to_bytes(100, 3) + DynamixelMotorsBus._serialize_data(100, 3) -def test_split_int_to_bytes_negative_numbers(): +def test_serialize_data_negative_numbers(): with pytest.raises(ValueError): - neg = DynamixelMotorsBus._split_int_to_bytes(-1, 1) + neg = DynamixelMotorsBus._serialize_data(-1, 1) print(neg) -def test_split_int_to_bytes_large_number(): +def test_serialize_data_large_number(): with pytest.raises(ValueError): - DynamixelMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF + DynamixelMotorsBus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF def test_abc_implementation(dummy_motors): diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 5372c37ad..da8194646 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -61,7 +61,7 @@ def test_autouse_patch(): @pytest.mark.parametrize( - "value, n_bytes, expected", + "value, length, expected", [ (0x12, 1, [0x12]), (0x1234, 2, [0x34, 0x12]), @@ -85,24 +85,24 @@ def test_autouse_patch(): "max four bytes", ], ) # fmt: skip -def test_split_int_to_bytes(value, n_bytes, expected): - assert FeetechMotorsBus._split_int_to_bytes(value, n_bytes) == expected +def test_serialize_data(value, length, expected): + assert FeetechMotorsBus._serialize_data(value, length) == expected -def test_split_int_to_bytes_invalid_n_bytes(): +def test_serialize_data_invalid_length(): with pytest.raises(NotImplementedError): - FeetechMotorsBus._split_int_to_bytes(100, 3) + FeetechMotorsBus._serialize_data(100, 3) -def test_split_int_to_bytes_negative_numbers(): +def test_serialize_data_negative_numbers(): with pytest.raises(ValueError): - neg = FeetechMotorsBus._split_int_to_bytes(-1, 1) + neg = FeetechMotorsBus._serialize_data(-1, 1) print(neg) -def test_split_int_to_bytes_large_number(): +def test_serialize_data_large_number(): with pytest.raises(ValueError): - FeetechMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF + FeetechMotorsBus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF def test_abc_implementation(dummy_motors): From 0464dc91b3526ce6e0595739901a53629e022d0a Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Apr 2025 11:02:01 +0200 Subject: [PATCH 11/23] Add feetech sm8512bl --- lerobot/common/motors/feetech/tables.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/common/motors/feetech/tables.py b/lerobot/common/motors/feetech/tables.py index e6d08cf82..176033174 100644 --- a/lerobot/common/motors/feetech/tables.py +++ b/lerobot/common/motors/feetech/tables.py @@ -150,7 +150,7 @@ MODEL_RESOLUTION = { "scs_series": 1024, "sts3215": 4096, "sts3250": 4096, - "sm8512bl": 4096, + "sm8512bl": 65536, "scs0009": 1024, } @@ -198,6 +198,6 @@ SCAN_BAUDRATES = [ MODEL_NUMBER_TABLE = { "sts3215": 777, "sts3250": None, - "sm8512bl": None, + "sm8512bl": 11272, "scs0009": 1284, } From 4ca92a28e9df839b3c13bb2d1670791c20bc9737 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Apr 2025 11:02:54 +0200 Subject: [PATCH 12/23] Make feetech broadcast ping faster in protocol 1 --- lerobot/common/motors/feetech/feetech.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index a0796f9c6..a89f7fb91 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -186,7 +186,9 @@ class FeetechMotorsBus(MotorsBus): ] return data - def _broadcast_ping_p1(self, known_motors_only: bool = True, num_retry: int = 0) -> dict[int, int]: + def _broadcast_ping_p1( + self, known_motors_only: bool = True, n_motors: int | None = None, num_retry: int = 0 + ) -> dict[int, int]: if known_motors_only: ids = self.ids else: @@ -195,10 +197,14 @@ class FeetechMotorsBus(MotorsBus): ids = range(scs.MAX_ID + 1) ids_models = {} + motors_found = 0 for id_ in ids: model_number = self.ping(id_, num_retry) if model_number is not None: ids_models[id_] = model_number + motors_found += 1 + if motors_found >= n_motors: + break return ids_models From 0a7f51f0daf114fe96afc825b8c8a100d61274c2 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Apr 2025 11:03:09 +0200 Subject: [PATCH 13/23] Cleanup --- lerobot/common/motors/feetech/feetech.py | 27 ------------------------ 1 file changed, 27 deletions(-) diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index a89f7fb91..a6b0c380b 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -310,20 +310,6 @@ class FeetechMotorsBus(MotorsBus): return self._broadcast_ping_p1(num_retry=num_retry) def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: - # comm, major = self._sync_read(*FIRMWARE_MAJOR_VERSION, motor_ids) - # if not self._is_comm_success(comm): - # if raise_on_error: - # raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - # return - - # comm, minor = self._sync_read(*FIRMWARE_MINOR_VERSION, motor_ids) - # if not self._is_comm_success(comm): - # if raise_on_error: - # raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - # return - - # return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids} - comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids) if not self._is_comm_success(comm): if raise_on_error: @@ -333,19 +319,6 @@ class FeetechMotorsBus(MotorsBus): return firmware_versions def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: - # comm, major = self._sync_read(*MODEL_MAJOR_VERSION, motor_ids) - # if not self._is_comm_success(comm): - # if raise_on_error: - # raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - # return - - # comm, minor = self._sync_read(*MODEL_MINOR_VERSION, motor_ids) - # if not self._is_comm_success(comm): - # if raise_on_error: - # raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - # return - - # return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids} if self.protocol_version == 1: model_numbers = {} for id_ in motor_ids: From 9e57ec7837891ab2b25c76304e071eeabc3c5709 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Apr 2025 11:58:09 +0200 Subject: [PATCH 14/23] Add support for feetech protocol 1 to _split_into_byte_chunks --- lerobot/common/motors/dynamixel/dynamixel.py | 35 ++++++----- lerobot/common/motors/feetech/feetech.py | 35 ++++++----- lerobot/common/motors/motors_bus.py | 3 +- tests/mocks/mock_dynamixel.py | 9 +-- tests/mocks/mock_feetech.py | 10 +-- tests/motors/test_dynamixel.py | 33 +--------- tests/motors/test_feetech.py | 53 +++++----------- tests/motors/test_motors_bus.py | 66 ++++++++++++++++++-- 8 files changed, 129 insertions(+), 115 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index a710afdec..1ebefac07 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -84,6 +84,23 @@ class TorqueMode(Enum): DISABLED = 0 +def _split_into_byte_chunks(value: int, length: int) -> list[int]: + import dynamixel_sdk as dxl + + if length == 1: + data = [value] + elif length == 2: + data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] + elif length == 4: + data = [ + dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), + ] + return data + + class DynamixelMotorsBus(MotorsBus): """ The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with @@ -166,22 +183,8 @@ class DynamixelMotorsBus(MotorsBus): return half_turn_homings - @staticmethod - def _split_into_byte_chunks(value: int, length: int) -> list[int]: - import dynamixel_sdk as dxl - - if length == 1: - data = [value] - elif length == 2: - data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] - elif length == 4: - data = [ - dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), - ] - return data + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: + return _split_into_byte_chunks(value, length) def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: for n_try in range(1 + num_retry): diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index a6b0c380b..5e957f2f5 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -64,6 +64,23 @@ class TorqueMode(Enum): DISABLED = 0 +def _split_into_byte_chunks(value: int, length: int) -> list[int]: + import scservo_sdk as scs + + if length == 1: + data = [value] + elif length == 2: + data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] + elif length == 4: + data = [ + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), + scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), + ] + return data + + def patch_setPacketTimeout(self, packet_length): # noqa: N802 """ HACK: This patches the PortHandler behavior to set the correct packet timeouts. @@ -169,22 +186,8 @@ class FeetechMotorsBus(MotorsBus): return ids_values - @staticmethod - def _split_into_byte_chunks(value: int, length: int) -> list[int]: - import scservo_sdk as scs - - if length == 1: - data = [value] - elif length == 2: - data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] - elif length == 4: - data = [ - scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), - scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), - scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), - scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), - ] - return data + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: + return _split_into_byte_chunks(value, length) def _broadcast_ping_p1( self, known_motors_only: bool = True, n_motors: int | None = None, num_retry: int = 0 diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 7bc8a4ae0..efc81166e 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -664,9 +664,8 @@ class MotorsBus(abc.ABC): return self._split_into_byte_chunks(value, length) - @staticmethod @abc.abstractmethod - def _split_into_byte_chunks(value: int, length: int) -> list[int]: + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: """Convert an integer into a list of byte-sized integers.""" pass diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index feae051bb..1c1ab6fec 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -5,7 +5,8 @@ import dynamixel_sdk as dxl import serial from mock_serial.mock_serial import MockSerial -from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE, DynamixelMotorsBus +from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE +from lerobot.common.motors.dynamixel.dynamixel import _split_into_byte_chunks from .mock_serial_patch import WaitableStub @@ -237,7 +238,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): +2 is for the length bytes, +2 is for the CRC at the end. """ - data = DynamixelMotorsBus._split_into_byte_chunks(value, data_length) + data = _split_into_byte_chunks(value, data_length) params = [ dxl.DXL_LOBYTE(start_address), dxl.DXL_HIBYTE(start_address), @@ -315,7 +316,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): """ data = [] for id_, value in ids_values.items(): - split_value = DynamixelMotorsBus._split_into_byte_chunks(value, data_length) + split_value = _split_into_byte_chunks(value, data_length) data += [id_, *split_value] params = [ dxl.DXL_LOBYTE(start_address), @@ -389,7 +390,7 @@ class MockStatusPacket(MockDynamixelPacketv2): Returns: bytes: The raw 'Present_Position' status packet ready to be sent through serial. """ - params = DynamixelMotorsBus._split_into_byte_chunks(value, param_length) + params = _split_into_byte_chunks(value, param_length) length = param_length + 4 return cls.build(dxl_id, params=params, length=length) diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 57bd8cbc7..2b54ae91f 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -5,8 +5,8 @@ import scservo_sdk as scs import serial from mock_serial import MockSerial -from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE, FeetechMotorsBus -from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout +from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE +from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout from .mock_serial_patch import WaitableStub @@ -139,7 +139,7 @@ class MockInstructionPacket(MockFeetechPacket): +1 is for the length bytes, +1 is for the checksum at the end. """ - data = FeetechMotorsBus._split_into_byte_chunks(value, data_length) + data = _split_into_byte_chunks(value, data_length) params = [start_address, *data] length = data_length + 3 return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") @@ -201,7 +201,7 @@ class MockInstructionPacket(MockFeetechPacket): """ data = [] for id_, value in ids_values.items(): - split_value = FeetechMotorsBus._split_into_byte_chunks(value, data_length) + split_value = _split_into_byte_chunks(value, data_length) data += [id_, *split_value] params = [start_address, data_length, *data] length = len(ids_values) * (1 + data_length) + 4 @@ -258,7 +258,7 @@ class MockStatusPacket(MockFeetechPacket): Returns: bytes: The raw 'Sync Read' status packet ready to be sent through serial. """ - params = FeetechMotorsBus._split_into_byte_chunks(value, param_length) + params = _split_into_byte_chunks(value, param_length) length = param_length + 2 return cls.build(scs_id, params=params, length=length) diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index e047e7c1c..2b7088360 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -67,43 +67,16 @@ def test_autouse_patch(): (0x12, 1, [0x12]), (0x1234, 2, [0x34, 0x12]), (0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), - (0, 1, [0x00]), - (0, 2, [0x00, 0x00]), - (0, 4, [0x00, 0x00, 0x00, 0x00]), - (255, 1, [0xFF]), - (65535, 2, [0xFF, 0xFF]), - (4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]), ], ids=[ "1 byte", "2 bytes", "4 bytes", - "0 with 1 byte", - "0 with 2 bytes", - "0 with 4 bytes", - "max single byte", - "max two bytes", - "max four bytes", ], ) # fmt: skip -def test_serialize_data(value, length, expected): - assert DynamixelMotorsBus._serialize_data(value, length) == expected - - -def test_serialize_data_invalid_length(): - with pytest.raises(NotImplementedError): - DynamixelMotorsBus._serialize_data(100, 3) - - -def test_serialize_data_negative_numbers(): - with pytest.raises(ValueError): - neg = DynamixelMotorsBus._serialize_data(-1, 1) - print(neg) - - -def test_serialize_data_large_number(): - with pytest.raises(ValueError): - DynamixelMotorsBus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF +def test__split_into_byte_chunks(value, length, expected): + bus = DynamixelMotorsBus("", {}) + assert bus._split_into_byte_chunks(value, length) == expected def test_abc_implementation(dummy_motors): diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index da8194646..2d3d4db77 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -61,48 +61,27 @@ def test_autouse_patch(): @pytest.mark.parametrize( - "value, length, expected", + "protocol, value, length, expected", [ - (0x12, 1, [0x12]), - (0x1234, 2, [0x34, 0x12]), - (0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), - (0, 1, [0x00]), - (0, 2, [0x00, 0x00]), - (0, 4, [0x00, 0x00, 0x00, 0x00]), - (255, 1, [0xFF]), - (65535, 2, [0xFF, 0xFF]), - (4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]), + (0, 0x12, 1, [0x12]), + (1, 0x12, 1, [0x12]), + (0, 0x1234, 2, [0x34, 0x12]), + (1, 0x1234, 2, [0x12, 0x34]), + (0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), + (1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]), ], ids=[ - "1 byte", - "2 bytes", - "4 bytes", - "0 with 1 byte", - "0 with 2 bytes", - "0 with 4 bytes", - "max single byte", - "max two bytes", - "max four bytes", + "P0: 1 byte", + "P1: 1 byte", + "P0: 2 bytes", + "P1: 2 bytes", + "P0: 4 bytes", + "P1: 4 bytes", ], ) # fmt: skip -def test_serialize_data(value, length, expected): - assert FeetechMotorsBus._serialize_data(value, length) == expected - - -def test_serialize_data_invalid_length(): - with pytest.raises(NotImplementedError): - FeetechMotorsBus._serialize_data(100, 3) - - -def test_serialize_data_negative_numbers(): - with pytest.raises(ValueError): - neg = FeetechMotorsBus._serialize_data(-1, 1) - print(neg) - - -def test_serialize_data_large_number(): - with pytest.raises(ValueError): - FeetechMotorsBus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF +def test__split_into_byte_chunks(protocol, value, length, expected): + bus = FeetechMotorsBus("", {}, protocol_version=protocol) + assert bus._split_into_byte_chunks(value, length) == expected def test_abc_implementation(dummy_motors): diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index 7463ae8c3..8ceaeefab 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -2,12 +2,50 @@ import re import pytest -from lerobot.common.motors.motors_bus import assert_same_address, get_address, get_ctrl_table +from lerobot.common.motors.motors_bus import ( + Motor, + MotorsBus, + assert_same_address, + get_address, + get_ctrl_table, +) -# TODO(aliberts) -# class DummyMotorsBus(MotorsBus): -# def __init__(self, port: str, motors: dict[str, Motor]): -# super().__init__(port, motors) +DUMMY_CTRL_TABLE = {"Present_Position": (13, 4)} + +DUMMY_BAUDRATE_TABLE = { + 0: 1_000_000, + 1: 500_000, +} + +DUMMY_ENCODING_TABLE = { + "Present_Position": 8, +} + +DUMMY_MODEL_NUMBER_TABLE = {""} + + +class DummyMotorsBus(MotorsBus): + available_baudrates = [1_000_000] + default_timeout = 1000 + model_baudrate_table = {"model": DUMMY_BAUDRATE_TABLE} + model_ctrl_table = {"model": DUMMY_CTRL_TABLE} + model_encoding_table = {"model": DUMMY_ENCODING_TABLE} + model_number_table = {"model": 1234} + model_resolution_table = {"model": 4096} + normalized_data = ["Present_Position"] + + def __init__(self, port: str, motors: dict[str, Motor]): + super().__init__(port, motors) + + def _assert_protocol_is_compatible(self, instruction_name): ... + def configure_motors(self): ... + def disable_torque(self, motors): ... + def enable_torque(self, motors): ... + def _get_half_turn_homings(self, positions): ... + def _encode_sign(self, data_name, ids_values): ... + def _decode_sign(self, data_name, ids_values): ... + def _split_into_byte_chunks(self, value, length): ... + def broadcast_ping(self, num_retry, raise_on_error): ... @pytest.fixture @@ -85,3 +123,21 @@ def test_assert_same_address_different_bytes(model_ctrl_table): match=re.escape("At least two motor models use a different bytes representation"), ): assert_same_address(model_ctrl_table, models, "Goal_Position") + + +def test__serialize_data_invalid_length(): + bus = DummyMotorsBus("", {}) + with pytest.raises(NotImplementedError): + bus._serialize_data(100, 3) + + +def test__serialize_data_negative_numbers(): + bus = DummyMotorsBus("", {}) + with pytest.raises(ValueError): + bus._serialize_data(-1, 1) + + +def test__serialize_data_large_number(): + bus = DummyMotorsBus("", {}) + with pytest.raises(ValueError): + bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF From f960f4d8d456d8effd9acf4a2320ae02569957b4 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Apr 2025 11:58:31 +0200 Subject: [PATCH 15/23] Fix unormalize --- lerobot/common/motors/motors_bus.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index efc81166e..cada33a7c 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -600,6 +600,7 @@ class MotorsBus(abc.ABC): def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]: if not self.calibration: raise RuntimeError(f"{self} has no calibration registered.") + normalized_values = {} for id_, val in ids_values.items(): name = self._id_to_name(id_) @@ -617,6 +618,9 @@ class MotorsBus(abc.ABC): return normalized_values def _unnormalize(self, data_name: str, ids_values: dict[int, float]) -> dict[int, int]: + if not self.calibration: + raise RuntimeError(f"{self} has no calibration registered.") + unnormalized_values = {} for id_, val in ids_values.items(): name = self._id_to_name(id_) From e0b292ab519338a8e61e6195d69af6c73381c2ee Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Apr 2025 12:24:30 +0200 Subject: [PATCH 16/23] Remove test_motors_bus fixtures --- tests/motors/test_motors_bus.py | 108 ++++++++++++++++---------------- 1 file changed, 53 insertions(+), 55 deletions(-) diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index 8ceaeefab..7797622ee 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -10,29 +10,56 @@ from lerobot.common.motors.motors_bus import ( get_ctrl_table, ) -DUMMY_CTRL_TABLE = {"Present_Position": (13, 4)} +DUMMY_CTRL_TABLE_1 = { + "Firmware_Version": (0, 1), + "Model_Number": (1, 2), + "Present_Position": (3, 4), + "Goal_Position": (7, 2), +} + +DUMMY_CTRL_TABLE_2 = { + "Model_Number": (0, 2), + "Firmware_Version": (2, 1), + "Present_Position": (3, 4), + "Goal_Position": (7, 4), + "Lock": (7, 4), +} + +DUMMY_MODEL_CTRL_TABLE = { + "model_1": DUMMY_CTRL_TABLE_1, + "model_2": DUMMY_CTRL_TABLE_2, +} DUMMY_BAUDRATE_TABLE = { 0: 1_000_000, 1: 500_000, } -DUMMY_ENCODING_TABLE = { - "Present_Position": 8, +DUMMY_MODEL_BAUDRATE_TABLE = { + "model_1": DUMMY_BAUDRATE_TABLE, + "model_2": DUMMY_BAUDRATE_TABLE, } -DUMMY_MODEL_NUMBER_TABLE = {""} +DUMMY_ENCODING_TABLE = { + "Present_Position": 8, + "Goal_Position": 10, +} + +DUMMY_MODEL_ENCODING_TABLE = { + "model_1": DUMMY_ENCODING_TABLE, + "model_2": DUMMY_ENCODING_TABLE, +} class DummyMotorsBus(MotorsBus): - available_baudrates = [1_000_000] + available_baudrates = [500_000, 1_000_000] default_timeout = 1000 - model_baudrate_table = {"model": DUMMY_BAUDRATE_TABLE} - model_ctrl_table = {"model": DUMMY_CTRL_TABLE} - model_encoding_table = {"model": DUMMY_ENCODING_TABLE} - model_number_table = {"model": 1234} - model_resolution_table = {"model": 4096} - normalized_data = ["Present_Position"] + model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE + model_ctrl_table = DUMMY_MODEL_CTRL_TABLE + model_encoding_table = DUMMY_MODEL_ENCODING_TABLE + model_number_table = {"model_1": 1234, "model_2": 5678} + model_resolution_table = {"model_1": 4096, "model_2": 1024} + normalized_data = ["Present_Position", "Goal_Position"] def __init__(self, port: str, motors: dict[str, Motor]): super().__init__(port, motors) @@ -48,81 +75,52 @@ class DummyMotorsBus(MotorsBus): def broadcast_ping(self, num_retry, raise_on_error): ... -@pytest.fixture -def ctrl_table_1() -> dict: - return { - "Firmware_Version": (0, 1), - "Model_Number": (1, 2), - "Present_Position": (3, 4), - "Goal_Position": (7, 2), - } - - -@pytest.fixture -def ctrl_table_2() -> dict: - return { - "Model_Number": (0, 2), - "Firmware_Version": (2, 1), - "Present_Position": (3, 4), - "Goal_Position": (7, 4), - "Lock": (7, 4), - } - - -@pytest.fixture -def model_ctrl_table(ctrl_table_1, ctrl_table_2) -> dict: - return { - "model_1": ctrl_table_1, - "model_2": ctrl_table_2, - } - - -def test_get_ctrl_table(model_ctrl_table, ctrl_table_1): +def test_get_ctrl_table(): model = "model_1" - ctrl_table = get_ctrl_table(model_ctrl_table, model) - assert ctrl_table == ctrl_table_1 + ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model) + assert ctrl_table == DUMMY_CTRL_TABLE_1 -def test_get_ctrl_table_error(model_ctrl_table): +def test_get_ctrl_table_error(): model = "model_99" with pytest.raises(KeyError, match=f"Control table for {model=} not found."): - get_ctrl_table(model_ctrl_table, model) + get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model) -def test_get_address(model_ctrl_table): - addr, n_bytes = get_address(model_ctrl_table, "model_1", "Firmware_Version") +def test_get_address(): + addr, n_bytes = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version") assert addr == 0 assert n_bytes == 1 -def test_get_address_error(model_ctrl_table): +def test_get_address_error(): model = "model_1" data_name = "Lock" with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."): - get_address(model_ctrl_table, "model_1", data_name) + get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", data_name) -def test_assert_same_address(model_ctrl_table): +def test_assert_same_address(): models = ["model_1", "model_2"] - assert_same_address(model_ctrl_table, models, "Present_Position") + assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position") -def test_assert_same_address_different_addresses(model_ctrl_table): +def test_assert_same_address_different_addresses(): models = ["model_1", "model_2"] with pytest.raises( NotImplementedError, match=re.escape("At least two motor models use a different address"), ): - assert_same_address(model_ctrl_table, models, "Model_Number") + assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number") -def test_assert_same_address_different_bytes(model_ctrl_table): +def test_assert_same_address_different_bytes(): models = ["model_1", "model_2"] with pytest.raises( NotImplementedError, match=re.escape("At least two motor models use a different bytes representation"), ): - assert_same_address(model_ctrl_table, models, "Goal_Position") + assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Goal_Position") def test__serialize_data_invalid_length(): From bdbca09cb2a8446987a864571ba42031238c86d1 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 14 Apr 2025 11:56:53 +0200 Subject: [PATCH 17/23] Add more segmented tests (base motor bus & feetech), add feetech protocol 1 support --- lerobot/common/motors/feetech/feetech.py | 101 ++++--- lerobot/common/motors/feetech/tables.py | 31 +- lerobot/common/motors/motors_bus.py | 114 ++++--- tests/mocks/mock_feetech.py | 144 +++++---- tests/motors/test_feetech.py | 363 +++++++++++++---------- tests/motors/test_motors_bus.py | 346 ++++++++++++++++++++- 6 files changed, 749 insertions(+), 350 deletions(-) diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 5e957f2f5..193f1b4a0 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -21,12 +21,13 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value from .tables import ( - FIRMWARE_VERSION, + FIRMWARE_MAJOR_VERSION, MODEL_BAUDRATE_TABLE, MODEL_CONTROL_TABLE, MODEL_ENCODING_TABLE, MODEL_NUMBER, MODEL_NUMBER_TABLE, + MODEL_PROTOCOL, MODEL_RESOLUTION, SCAN_BAUDRATES, ) @@ -117,9 +118,10 @@ class FeetechMotorsBus(MotorsBus): protocol_version: int = DEFAULT_PROTOCOL_VERSION, ): super().__init__(port, motors, calibration) + self.protocol_version = protocol_version + self._assert_same_protocol() import scservo_sdk as scs - self.protocol_version = protocol_version self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( @@ -131,10 +133,21 @@ class FeetechMotorsBus(MotorsBus): self._comm_success = scs.COMM_SUCCESS self._no_error = 0x00 + if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models): + raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}") + + def _assert_same_protocol(self) -> None: + if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models): + raise RuntimeError("Some motors use an incompatible protocol.") + def _assert_protocol_is_compatible(self, instruction_name: str) -> None: if instruction_name == "sync_read" and self.protocol_version == 1: raise NotImplementedError( - "'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' instead." + "'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead." + ) + if instruction_name == "broadcast_ping" and self.protocol_version == 1: + raise NotImplementedError( + "'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead." ) def configure_motors(self) -> None: @@ -157,12 +170,12 @@ class FeetechMotorsBus(MotorsBus): return half_turn_homings def disable_torque(self, motors: str | list[str] | None = None) -> None: - for name in self._get_names_list(motors): + for name in self._get_motors_list(motors): self.write("Torque_Enable", name, TorqueMode.DISABLED.value) self.write("Lock", name, 0) def enable_torque(self, motors: str | list[str] | None = None) -> None: - for name in self._get_names_list(motors): + for name in self._get_motors_list(motors): self.write("Torque_Enable", name, TorqueMode.ENABLED.value) self.write("Lock", name, 1) @@ -286,56 +299,52 @@ class FeetechMotorsBus(MotorsBus): rx_length = rx_length - idx def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: - if self.protocol_version == 0: - for n_try in range(1 + num_retry): - ids_status, comm = self._broadcast_ping_p0() - if self._is_comm_success(comm): - break - logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") - logger.debug(self.packet_handler.getTxRxResult(comm)) + self._assert_protocol_is_compatible("broadcast_ping") + for n_try in range(1 + num_retry): + ids_status, comm = self._broadcast_ping_p0() + if self._is_comm_success(comm): + break + logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") + logger.debug(self.packet_handler.getTxRxResult(comm)) - if not self._is_comm_success(comm): - if raise_on_error: - raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return - - ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} - if ids_errors: - display_dict = { - id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items() - } - logger.error( - f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}" - ) - - return self._get_model_number(list(ids_status), raise_on_error) - else: - return self._broadcast_ping_p1(num_retry=num_retry) - - def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: - comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids) if not self._is_comm_success(comm): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) return + ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} + if ids_errors: + display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} + logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") + + return self._get_model_number(list(ids_status), raise_on_error) + + def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]: + firmware_versions = {} + for id_ in motor_ids: + firm_ver_major, comm, error = self._read( + *FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error + ) + if not self._is_comm_success(comm) or self._is_error(error): + return + + firm_ver_minor, comm, error = self._read( + *FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error + ) + if not self._is_comm_success(comm) or self._is_error(error): + return + + firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}" + return firmware_versions def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: - if self.protocol_version == 1: - model_numbers = {} - for id_ in motor_ids: - model_nb, comm, error = self._read(*MODEL_NUMBER, id_) - if self._is_comm_success(comm) and not self._is_error(error): - model_numbers[id_] = model_nb - elif raise_on_error: - raise Exception # FIX - - else: - comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids) - if not self._is_comm_success(comm): - if raise_on_error: - raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + model_numbers = {} + for id_ in motor_ids: + model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error) + if not self._is_comm_success(comm) or self._is_error(error): return + model_numbers[id_] = model_nb + return model_numbers diff --git a/lerobot/common/motors/feetech/tables.py b/lerobot/common/motors/feetech/tables.py index 176033174..ada8d08fd 100644 --- a/lerobot/common/motors/feetech/tables.py +++ b/lerobot/common/motors/feetech/tables.py @@ -1,9 +1,5 @@ FIRMWARE_MAJOR_VERSION = (0, 1) FIRMWARE_MINOR_VERSION = (1, 1) -MODEL_MAJOR_VERSION = (3, 1) -MODEL_MINOR_VERSION = (4, 1) - -FIRMWARE_VERSION = (0, 2) MODEL_NUMBER = (3, 2) # See this link for STS3215 Memory Table: @@ -11,12 +7,9 @@ MODEL_NUMBER = (3, 2) # data_name: (address, size_byte) STS_SMS_SERIES_CONTROL_TABLE = { # EPROM - "Firmware_Version": FIRMWARE_VERSION, # read-only + "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only + "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only "Model_Number": MODEL_NUMBER, # read-only - # "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only - # "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only - # "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only - # "Model_Minor_Version": MODEL_MINOR_VERSION, "ID": (5, 1), "Baud_Rate": (6, 1), "Return_Delay_Time": (7, 1), @@ -68,12 +61,9 @@ STS_SMS_SERIES_CONTROL_TABLE = { SCS_SERIES_CONTROL_TABLE = { # EPROM - "Firmware_Version": FIRMWARE_VERSION, # read-only + "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only + "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only "Model_Number": MODEL_NUMBER, # read-only - # "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only - # "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only - # "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only - # "Model_Minor_Version": MODEL_MINOR_VERSION, "ID": (5, 1), "Baud_Rate": (6, 1), "Return_Delay": (7, 1), @@ -194,10 +184,19 @@ SCAN_BAUDRATES = [ 1_000_000, ] -# {model: model_number} TODO MODEL_NUMBER_TABLE = { "sts3215": 777, - "sts3250": None, + "sts3250": 2825, "sm8512bl": 11272, "scs0009": 1284, } + +MODEL_PROTOCOL = { + "sts_series": 0, + "sms_series": 0, + "scs_series": 1, + "sts3215": 0, + "sts3250": 0, + "sm8512bl": 0, + "scs0009": 1, +} diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index cada33a7c..d0f8ff3ed 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -283,6 +283,8 @@ class MotorsBus(abc.ABC): self._id_to_name_dict = {m.id: name for name, m in self.motors.items()} self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()} + self._validate_motors() + def __len__(self): return len(self.motors) @@ -341,7 +343,7 @@ class MotorsBus(abc.ABC): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_names_list(self, motors: str | list[str] | None) -> list[str]: + def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: if motors is None: return self.names elif isinstance(motors, str): @@ -422,8 +424,8 @@ class MotorsBus(abc.ABC): logger.debug(f"{self.__class__.__name__} connected.") @classmethod - def scan_port(cls, port: str) -> dict[int, list[int]]: - bus = cls(port, {}) + def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]: + bus = cls(port, {}, *args, **kwargs) try: bus.port_handler.openPort() except (FileNotFoundError, OSError, serial.SerialException) as e: @@ -715,17 +717,8 @@ class MotorsBus(abc.ABC): model = self.motors[motor].model addr, length = get_address(self.model_ctrl_table, model, data_name) - value, comm, error = self._read(addr, length, id_, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." - f"{self.packet_handler.getTxRxResult(comm)}" - ) - elif self._is_error(error): - raise RuntimeError( - f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." - f"\n{self.packet_handler.getRxPacketError(error)}" - ) + err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." + value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) id_value = self._decode_sign(data_name, {id_: value}) @@ -734,7 +727,16 @@ class MotorsBus(abc.ABC): return id_value[id_] - def _read(self, address: int, length: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]: + def _read( + self, + address: int, + length: int, + motor_id: int, + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> tuple[int, int]: if length == 1: read_fn = self.packet_handler.read1ByteTxRx elif length == 2: @@ -753,6 +755,11 @@ class MotorsBus(abc.ABC): + self.packet_handler.getTxRxResult(comm) ) + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + elif self._is_error(error) and raise_on_error: + raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}") + return value, comm, error def write( @@ -772,20 +779,19 @@ class MotorsBus(abc.ABC): value = self._encode_sign(data_name, {id_: value})[id_] - comm, error = self._write(addr, length, id_, value, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - f"\n{self.packet_handler.getTxRxResult(comm)}" - ) - elif self._is_error(error): - raise RuntimeError( - f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - f"\n{self.packet_handler.getRxPacketError(error)}" - ) + err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." + self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) def _write( - self, addr: int, length: int, motor_id: int, value: int, num_retry: int = 0 + self, + addr: int, + length: int, + motor_id: int, + value: int, + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", ) -> tuple[int, int]: data = self._serialize_data(value, length) for n_try in range(1 + num_retry): @@ -797,6 +803,11 @@ class MotorsBus(abc.ABC): + self.packet_handler.getTxRxResult(comm) ) + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + elif self._is_error(error) and raise_on_error: + raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}") + return comm, error def sync_read( @@ -814,7 +825,7 @@ class MotorsBus(abc.ABC): self._assert_protocol_is_compatible("sync_read") - names = self._get_names_list(motors) + names = self._get_motors_list(motors) ids = [self.motors[name].id for name in names] models = [self.motors[name].model for name in names] @@ -824,12 +835,10 @@ class MotorsBus(abc.ABC): model = next(iter(models)) addr, length = get_address(self.model_ctrl_table, model, data_name) - comm, ids_values = self._sync_read(addr, length, ids, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." - f"{self.packet_handler.getTxRxResult(comm)}" - ) + err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." + ids_values, _ = self._sync_read( + addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg + ) ids_values = self._decode_sign(data_name, ids_values) @@ -839,8 +848,15 @@ class MotorsBus(abc.ABC): return {self._id_to_name(id_): value for id_, value in ids_values.items()} def _sync_read( - self, addr: int, length: int, motor_ids: list[int], num_retry: int = 0 - ) -> tuple[int, dict[int, int]]: + self, + addr: int, + length: int, + motor_ids: list[int], + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> tuple[dict[int, int], int]: self._setup_sync_reader(motor_ids, addr, length) for n_try in range(1 + num_retry): comm = self.sync_reader.txRxPacket() @@ -851,8 +867,11 @@ class MotorsBus(abc.ABC): + self.packet_handler.getTxRxResult(comm) ) + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids} - return comm, values + return values, comm def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None: self.sync_reader.clearParam() @@ -901,14 +920,18 @@ class MotorsBus(abc.ABC): ids_values = self._encode_sign(data_name, ids_values) - comm = self._sync_write(addr, length, ids_values, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." - f"\n{self.packet_handler.getTxRxResult(comm)}" - ) + err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." + self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) - def _sync_write(self, addr: int, length: int, ids_values: dict[int, int], num_retry: int = 0) -> int: + def _sync_write( + self, + addr: int, + length: int, + ids_values: dict[int, int], + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> int: self._setup_sync_writer(ids_values, addr, length) for n_try in range(1 + num_retry): comm = self.sync_writer.txPacket() @@ -919,6 +942,9 @@ class MotorsBus(abc.ABC): + self.packet_handler.getTxRxResult(comm) ) + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + return comm def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None: diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 2b54ae91f..f4bb1c686 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -10,27 +10,6 @@ from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch from .mock_serial_patch import WaitableStub -# https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf -INSTRUCTION_TYPES = { - "Read": scs.INST_PING, # Read data from the Device - "Ping": scs.INST_READ, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID - "Write": scs.INST_WRITE, # Write data to the Device - "Reg_Write": scs.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command - "Action": scs.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write - "Factory_Reset": 0x06, # Resets the Control Table to its initial factory default settings - "Sync_Write": scs.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once - "Sync_Read": scs.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once -} # fmt: skip - -ERROR_TYPE = { - "Success": 0x00, - "Voltage": scs.ERRBIT_VOLTAGE, - "Angle": scs.ERRBIT_ANGLE, - "Overheat": scs.ERRBIT_OVERHEAT, - "Overele": scs.ERRBIT_OVERELE, - "Overload": scs.ERRBIT_OVERLOAD, -} - class MockFeetechPacket(abc.ABC): @classmethod @@ -68,15 +47,14 @@ class MockInstructionPacket(MockFeetechPacket): """ @classmethod - def _build(cls, scs_id: int, params: list[int], length: int, instruct_type: str) -> list[int]: - instruct_value = INSTRUCTION_TYPES[instruct_type] + def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]: return [ - 0xFF, 0xFF, # header - scs_id, # servo id - length, # length - instruct_value, # instruction type - *params, # data bytes - 0x00, # placeholder for checksum + 0xFF, 0xFF, # header + scs_id, # servo id + length, # length + instruction, # instruction type + *params, # data bytes + 0x00, # placeholder for checksum ] # fmt: skip @classmethod @@ -89,7 +67,7 @@ class MockInstructionPacket(MockFeetechPacket): No parameters required. """ - return cls.build(scs_id=scs_id, params=[], length=2, instruct_type="Ping") + return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING) @classmethod def read( @@ -113,7 +91,7 @@ class MockInstructionPacket(MockFeetechPacket): """ params = [start_address, data_length] length = 4 - return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Read") + return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ) @classmethod def write( @@ -142,7 +120,7 @@ class MockInstructionPacket(MockFeetechPacket): data = _split_into_byte_chunks(value, data_length) params = [start_address, *data] length = data_length + 3 - return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") + return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE) @classmethod def sync_read( @@ -167,7 +145,9 @@ class MockInstructionPacket(MockFeetechPacket): """ params = [start_address, data_length, *scs_ids] length = len(scs_ids) + 4 - return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read") + return cls.build( + scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ + ) @classmethod def sync_write( @@ -205,7 +185,9 @@ class MockInstructionPacket(MockFeetechPacket): data += [id_, *split_value] params = [start_address, data_length, *data] length = len(ids_values) * (1 + data_length) + 4 - return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") + return cls.build( + scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE + ) class MockStatusPacket(MockFeetechPacket): @@ -222,19 +204,18 @@ class MockStatusPacket(MockFeetechPacket): """ @classmethod - def _build(cls, scs_id: int, params: list[int], length: int, error: str = "Success") -> list[int]: - err_byte = ERROR_TYPE[error] + def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]: return [ 0xFF, 0xFF, # header scs_id, # servo id length, # length - err_byte, # status + error, # status *params, # data bytes 0x00, # placeholder for checksum ] # fmt: skip @classmethod - def ping(cls, scs_id: int, error: str = "Success") -> bytes: + def ping(cls, scs_id: int, error: int = 0) -> bytes: """Builds a 'Ping' status packet. Args: @@ -247,7 +228,7 @@ class MockStatusPacket(MockFeetechPacket): return cls.build(scs_id, params=[], length=2, error=error) @classmethod - def read(cls, scs_id: int, value: int, param_length: int) -> bytes: + def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes: """Builds a 'Read' status packet. Args: @@ -260,7 +241,7 @@ class MockStatusPacket(MockFeetechPacket): """ params = _split_into_byte_chunks(value, param_length) length = param_length + 2 - return cls.build(scs_id, params=params, length=length) + return cls.build(scs_id, params=params, length=length, error=error) class MockPortHandler(scs.PortHandler): @@ -323,11 +304,11 @@ class MockMotors(MockSerial): ) return stub_name - def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0) -> str: + def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str: ping_request = MockInstructionPacket.ping(scs_id) - return_packet = MockStatusPacket.ping(scs_id) + return_packet = MockStatusPacket.ping(scs_id, error) ping_response = self._build_send_fn(return_packet, num_invalid_try) - stub_name = f"Ping_{scs_id}" + stub_name = f"Ping_{scs_id}_{error}" self.stub( name=stub_name, receive_bytes=ping_request, @@ -336,13 +317,19 @@ class MockMotors(MockSerial): return stub_name def build_read_stub( - self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0 + self, + address: int, + length: int, + scs_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, ) -> str: - address, length = self.ctrl_table[data_name] read_request = MockInstructionPacket.read(scs_id, address, length) - return_packet = MockStatusPacket.read(scs_id, value, length) + return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b"" read_response = self._build_send_fn(return_packet, num_invalid_try) - stub_name = f"Read_{data_name}_{scs_id}" + stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}" self.stub( name=stub_name, receive_bytes=read_request, @@ -350,15 +337,42 @@ class MockMotors(MockSerial): ) return stub_name - def build_sync_read_stub( - self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + def build_write_stub( + self, + address: int, + length: int, + scs_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, ) -> str: - address, length = self.ctrl_table[data_name] - sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) - return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) + sync_read_request = MockInstructionPacket.write(scs_id, value, address, length) + return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b"" + stub_name = f"Write_{address}_{length}_{scs_id}" + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(return_packet, num_invalid_try), + ) + return stub_name + def build_sync_read_stub( + self, + address: int, + length: int, + ids_values: dict[int, int], + reply: bool = True, + num_invalid_try: int = 0, + ) -> str: + sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) + return_packets = ( + b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) + if reply + else b"" + ) sync_read_response = self._build_send_fn(return_packets, num_invalid_try) - stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -367,11 +381,10 @@ class MockMotors(MockSerial): return stub_name def build_sequential_sync_read_stub( - self, data_name: str, ids_values: dict[int, list[int]] | None = None + self, address: int, length: int, ids_values: dict[int, list[int]] | None = None ) -> str: sequence_length = len(next(iter(ids_values.values()))) assert all(len(positions) == sequence_length for positions in ids_values.values()) - address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) sequential_packets = [] for count in range(sequence_length): @@ -381,7 +394,7 @@ class MockMotors(MockSerial): sequential_packets.append(return_packets) sync_read_response = self._build_sequential_send_fn(sequential_packets) - stub_name = f"Seq_Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -390,11 +403,10 @@ class MockMotors(MockSerial): return stub_name def build_sync_write_stub( - self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0 ) -> str: - address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) - stub_name = f"Sync_Write_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -402,20 +414,6 @@ class MockMotors(MockSerial): ) return stub_name - def build_write_stub( - self, data_name: str, scs_id: int, value: int, error: str = "Success", num_invalid_try: int = 0 - ) -> str: - address, length = self.ctrl_table[data_name] - sync_read_request = MockInstructionPacket.write(scs_id, value, address, length) - return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) - stub_name = f"Write_{data_name}_{scs_id}" - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=self._build_send_fn(return_packet, num_invalid_try), - ) - return stub_name - @staticmethod def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: def send_fn(_call_count: int) -> bytes: diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 2d3d4db77..d25b98bc6 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -1,3 +1,4 @@ +import re import sys from typing import Generator from unittest.mock import MagicMock, patch @@ -6,7 +7,8 @@ import pytest import scservo_sdk as scs from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode -from lerobot.common.motors.feetech import MODEL_NUMBER_TABLE, FeetechMotorsBus +from lerobot.common.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus +from lerobot.common.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE from lerobot.common.utils.encoding_utils import encode_sign_magnitude from tests.mocks.mock_feetech import MockMotors, MockPortHandler @@ -109,8 +111,9 @@ def test_scan_port(mock_motors): @pytest.mark.parametrize("id_", [1, 2, 3]) def test_ping(id_, mock_motors, dummy_motors): expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] + addr, length = MODEL_NUMBER ping_stub = mock_motors.build_ping_stub(id_) - mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb) + mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -126,9 +129,15 @@ def test_ping(id_, mock_motors, dummy_motors): def test_broadcast_ping(mock_motors, dummy_motors): models = {m.id: m.model for m in dummy_motors.values()} - expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()} + addr, length = MODEL_NUMBER ping_stub = mock_motors.build_broadcast_ping_stub(list(models)) - mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs) + mobel_nb_stubs = [] + expected_model_nbs = {} + for id_, model in models.items(): + model_nb = MODEL_NUMBER_TABLE[model] + stub = mock_motors.build_read_stub(addr, length, id_, model_nb) + expected_model_nbs[id_] = model_nb + mobel_nb_stubs.append(stub) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -139,187 +148,209 @@ def test_broadcast_ping(mock_motors, dummy_motors): assert ping_model_nbs == expected_model_nbs assert mock_motors.stubs[ping_stub].called - assert mock_motors.stubs[mobel_nb_stub].called - - -def test_sync_read_none(mock_motors, dummy_motors): - expected_positions = { - "dummy_1": 1337, - "dummy_2": 42, - "dummy_3": 4016, - } - ids_values = dict(zip([1, 2, 3], expected_positions.values(), strict=True)) - stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - read_positions = motors_bus.sync_read("Present_Position", normalize=False) - - assert mock_motors.stubs[stub_name].called - assert read_positions == expected_positions + assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs) @pytest.mark.parametrize( - "id_, position", + "addr, length, id_, value", [ - (1, 1337), - (2, 42), - (3, 4016), + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), ], ) -def test_sync_read_single_value(id_, position, mock_motors, dummy_motors): - expected_position = {f"dummy_{id_}": position} - stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position}) +def test__read(addr, length, id_, value, mock_motors, dummy_motors): + stub_name = mock_motors.build_read_stub(addr, length, id_, value) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - read_position = motors_bus.sync_read("Present_Position", f"dummy_{id_}", normalize=False) + read_value, _, _ = motors_bus._read(addr, length, id_) assert mock_motors.stubs[stub_name].called - assert read_position == expected_position + assert read_value == value -@pytest.mark.parametrize( - "ids, positions", - [ - ([1], [1337]), - ([1, 2], [1337, 42]), - ([1, 2, 3], [1337, 42, 4016]), - ], - ids=["1 motor", "2 motors", "3 motors"], -) # fmt: skip -def test_sync_read(ids, positions, mock_motors, dummy_motors): - assert len(ids) == len(positions) - names = [f"dummy_{dxl_id}" for dxl_id in ids] - expected_positions = dict(zip(names, positions, strict=True)) - ids_values = dict(zip(ids, positions, strict=True)) - stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values) +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) + stub_name = mock_motors.build_read_stub(addr, length, id_, value, error=error) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - read_positions = motors_bus.sync_read("Present_Position", names, normalize=False) - - assert mock_motors.stubs[stub_name].called - assert read_positions == expected_positions - - -@pytest.mark.parametrize( - "num_retry, num_invalid_try, pos", - [ - (0, 2, 1337), - (2, 3, 42), - (3, 2, 4016), - (2, 1, 999), - ], -) -def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors): - expected_position = {"dummy_1": pos} - stub_name = mock_motors.build_sync_read_stub( - "Present_Position", {1: pos}, num_invalid_try=num_invalid_try - ) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - if num_retry >= num_invalid_try: - pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry) - assert pos_dict == expected_position + if raise_on_error: + with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): + motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) else: - with pytest.raises(ConnectionError): - _ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry) - - expected_calls = min(1 + num_retry, 1 + num_invalid_try) - assert mock_motors.stubs[stub_name].calls == expected_calls - - -@pytest.mark.parametrize( - "data_name, value", - [ - ("Torque_Enable", 0), - ("Torque_Enable", 1), - ("Goal_Position", 1337), - ("Goal_Position", 42), - ], -) -def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors): - ids_values = {m.id: value for m in dummy_motors.values()} - stub_name = mock_motors.build_sync_write_stub(data_name, ids_values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - motors_bus.sync_write(data_name, value, normalize=False) - - assert mock_motors.stubs[stub_name].wait_called() - - -@pytest.mark.parametrize( - "ids, positions", - [ - ([1], [1337]), - ([1, 2], [1337, 42]), - ([1, 2, 3], [1337, 42, 4016]), - ], - ids=["1 motor", "2 motors", "3 motors"], -) # fmt: skip -def test_sync_write(ids, positions, mock_motors, dummy_motors): - assert len(ids) == len(positions) - ids_values = dict(zip(ids, positions, strict=True)) - stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - write_values = {f"dummy_{id_}": pos for id_, pos in ids_values.items()} - motors_bus.sync_write("Goal_Position", write_values, normalize=False) - - assert mock_motors.stubs[stub_name].wait_called() - - -@pytest.mark.parametrize( - "data_name, dxl_id, value", - [ - ("Torque_Enable", 1, 0), - ("Torque_Enable", 1, 1), - ("Goal_Position", 2, 1337), - ("Goal_Position", 3, 42), - ], -) -def test_write(data_name, dxl_id, value, mock_motors, dummy_motors): - stub_name = mock_motors.build_write_stub(data_name, dxl_id, value) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - motors_bus.write(data_name, f"dummy_{dxl_id}", value, normalize=False) + _, _, read_error = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_error == error assert mock_motors.stubs[stub_name].called +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub_name = mock_motors.build_read_stub(addr, length, id_, value, reply=False) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + else: + _, read_comm, _ = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub_name].called + + +@pytest.mark.parametrize( + "addr, length, id_, value", + [ + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), + ], +) +def test__write(addr, length, id_, value, mock_motors, dummy_motors): + stub_name = mock_motors.build_write_stub(addr, length, id_, value) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + comm, error = motors_bus._write(addr, length, id_, value) + + assert mock_motors.stubs[stub_name].called + assert comm == scs.COMM_SUCCESS + assert error == 0 + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) + stub_name = mock_motors.build_write_stub(addr, length, id_, value, error=error) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): + motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + _, write_error = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_error == error + + assert mock_motors.stubs[stub_name].called + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub_name = mock_motors.build_write_stub(addr, length, id_, value, reply=False) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + write_comm, _ = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub_name].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): + stub_name = mock_motors.build_sync_read_stub(addr, length, ids_values) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + read_values, _ = motors_bus._sync_read(addr, length, list(ids_values)) + + assert mock_motors.stubs[stub_name].called + assert read_values == ids_values + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, ids_values = (10, 4, {1: 1337}) + stub_name = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + else: + _, read_comm = motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + assert read_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub_name].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): + stub_name = mock_motors.build_sync_write_stub(addr, length, ids_values) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + comm = motors_bus._sync_write(addr, length, ids_values) + + assert mock_motors.stubs[stub_name].wait_called() + assert comm == scs.COMM_SUCCESS + + def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): encoded_homings = {m.id: encode_sign_magnitude(m.homing_offset, 11) for m in dummy_calibration.values()} mins = {m.id: m.range_min for m in dummy_calibration.values()} maxes = {m.id: m.range_max for m in dummy_calibration.values()} - offsets_stub = mock_motors.build_sync_read_stub("Homing_Offset", encoded_homings) - mins_stub = mock_motors.build_sync_read_stub("Min_Position_Limit", mins) - maxes_stub = mock_motors.build_sync_read_stub("Max_Position_Limit", maxes) + offsets_stub = mock_motors.build_sync_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings + ) + mins_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins) + maxes_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -340,9 +371,15 @@ def test_reset_calibration(mock_motors, dummy_motors): write_mins_stubs = [] write_maxes_stubs = [] for motor in dummy_motors.values(): - write_homing_stubs.append(mock_motors.build_write_stub("Homing_Offset", motor.id, 0)) - write_mins_stubs.append(mock_motors.build_write_stub("Min_Position_Limit", motor.id, 0)) - write_maxes_stubs.append(mock_motors.build_write_stub("Max_Position_Limit", motor.id, 4095)) + write_homing_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0) + ) + write_mins_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0) + ) + write_maxes_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) + ) motors_bus = FeetechMotorsBus( port=mock_motors.port, @@ -372,11 +409,15 @@ def test_set_half_turn_homings(mock_motors, dummy_motors): 2: -2005, # 42 - 2047 3: 1625, # 3672 - 2047 } - read_pos_stub = mock_motors.build_sync_read_stub("Present_Position", current_positions) + read_pos_stub = mock_motors.build_sync_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions + ) write_homing_stubs = [] for id_, homing in expected_homings.items(): encoded_homing = encode_sign_magnitude(homing, 11) - stub = mock_motors.build_write_stub("Homing_Offset", id_, encoded_homing) + stub = mock_motors.build_write_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing + ) write_homing_stubs.append(stub) motors_bus = FeetechMotorsBus( @@ -409,7 +450,9 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors): "dummy_2": 3600, "dummy_3": 4002, } - read_pos_stub = mock_motors.build_sequential_sync_read_stub("Present_Position", positions) + read_pos_stub = mock_motors.build_sequential_sync_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions + ) with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): motors_bus = FeetechMotorsBus( port=mock_motors.port, diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index 7797622ee..c98cda7dd 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -1,9 +1,13 @@ +# ruff: noqa: N802 + import re +from unittest.mock import patch import pytest from lerobot.common.motors.motors_bus import ( Motor, + MotorNormMode, MotorsBus, assert_same_address, get_address, @@ -14,30 +18,35 @@ DUMMY_CTRL_TABLE_1 = { "Firmware_Version": (0, 1), "Model_Number": (1, 2), "Present_Position": (3, 4), - "Goal_Position": (7, 2), + "Goal_Position": (11, 2), } DUMMY_CTRL_TABLE_2 = { "Model_Number": (0, 2), "Firmware_Version": (2, 1), "Present_Position": (3, 4), - "Goal_Position": (7, 4), - "Lock": (7, 4), + "Present_Velocity": (7, 4), + "Goal_Position": (11, 4), + "Goal_Velocity": (15, 4), + "Lock": (19, 1), } DUMMY_MODEL_CTRL_TABLE = { "model_1": DUMMY_CTRL_TABLE_1, "model_2": DUMMY_CTRL_TABLE_2, + "model_3": DUMMY_CTRL_TABLE_2, } DUMMY_BAUDRATE_TABLE = { 0: 1_000_000, 1: 500_000, + 2: 250_000, } DUMMY_MODEL_BAUDRATE_TABLE = { "model_1": DUMMY_BAUDRATE_TABLE, "model_2": DUMMY_BAUDRATE_TABLE, + "model_3": DUMMY_BAUDRATE_TABLE, } DUMMY_ENCODING_TABLE = { @@ -48,21 +57,78 @@ DUMMY_ENCODING_TABLE = { DUMMY_MODEL_ENCODING_TABLE = { "model_1": DUMMY_ENCODING_TABLE, "model_2": DUMMY_ENCODING_TABLE, + "model_3": DUMMY_ENCODING_TABLE, +} + +DUMMY_MODEL_NUMBER_TABLE = { + "model_1": 1234, + "model_2": 5678, + "model_3": 5799, +} + +DUMMY_MODEL_RESOLUTION_TABLE = { + "model_1": 4096, + "model_2": 1024, + "model_3": 4096, } -class DummyMotorsBus(MotorsBus): +class MockPortHandler: + def __init__(self, port_name): + self.is_open: bool = False + self.baudrate: int + self.packet_start_time: float + self.packet_timeout: float + self.tx_time_per_byte: float + self.is_using: bool = False + self.port_name: str = port_name + self.ser = None + + def openPort(self): + self.is_open = True + return self.is_open + + def closePort(self): + self.is_open = False + + def clearPort(self): ... + def setPortName(self, port_name): + self.port_name = port_name + + def getPortName(self): + return self.port_name + + def setBaudRate(self, baudrate): + self.baudrate: baudrate + + def getBaudRate(self): + return self.baudrate + + def getBytesAvailable(self): ... + def readPort(self, length): ... + def writePort(self, packet): ... + def setPacketTimeout(self, packet_length): ... + def setPacketTimeoutMillis(self, msec): ... + def isPacketTimeout(self): ... + def getCurrentTime(self): ... + def getTimeSinceStart(self): ... + def setupPort(self, cflag_baud): ... + def getCFlagBaud(self, baudrate): ... + + +class MockMotorsBus(MotorsBus): available_baudrates = [500_000, 1_000_000] default_timeout = 1000 model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE model_ctrl_table = DUMMY_MODEL_CTRL_TABLE model_encoding_table = DUMMY_MODEL_ENCODING_TABLE - model_number_table = {"model_1": 1234, "model_2": 5678} - model_resolution_table = {"model_1": 4096, "model_2": 1024} + model_number_table = DUMMY_MODEL_NUMBER_TABLE + model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE normalized_data = ["Present_Position", "Goal_Position"] def __init__(self, port: str, motors: dict[str, Motor]): super().__init__(port, motors) + self.port_handler = MockPortHandler(port) def _assert_protocol_is_compatible(self, instruction_name): ... def configure_motors(self): ... @@ -75,6 +141,15 @@ class DummyMotorsBus(MotorsBus): def broadcast_ping(self, num_retry, raise_on_error): ... +@pytest.fixture +def dummy_motors() -> dict[str, Motor]: + return { + "dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100), + "dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100), + "dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100), + } + + def test_get_ctrl_table(): model = "model_1" ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model) @@ -105,7 +180,7 @@ def test_assert_same_address(): assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position") -def test_assert_same_address_different_addresses(): +def test_assert_same_length_different_addresses(): models = ["model_1", "model_2"] with pytest.raises( NotImplementedError, @@ -114,7 +189,7 @@ def test_assert_same_address_different_addresses(): assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number") -def test_assert_same_address_different_bytes(): +def test_assert_same_address_different_length(): models = ["model_1", "model_2"] with pytest.raises( NotImplementedError, @@ -124,18 +199,267 @@ def test_assert_same_address_different_bytes(): def test__serialize_data_invalid_length(): - bus = DummyMotorsBus("", {}) + bus = MockMotorsBus("", {}) with pytest.raises(NotImplementedError): bus._serialize_data(100, 3) def test__serialize_data_negative_numbers(): - bus = DummyMotorsBus("", {}) + bus = MockMotorsBus("", {}) with pytest.raises(ValueError): bus._serialize_data(-1, 1) def test__serialize_data_large_number(): - bus = DummyMotorsBus("", {}) + bus = MockMotorsBus("", {}) with pytest.raises(ValueError): bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Firmware_Version", 1, 14), + ("Model_Number", 1, 5678), + ("Present_Position", 2, 1337), + ("Present_Velocity", 3, 42), + ], +) +def test_read(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + + with ( + patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read, + patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize, + ): + returned_value = bus.read(data_name, f"dummy_{id_}") + + assert returned_value == value + mock__read.assert_called_once_with( + addr, + length, + id_, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(data_name, {id_: value}) + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Goal_Position", 1, 1337), + ("Goal_Velocity", 2, 3682), + ("Lock", 3, 1), + ], +) +def test_write(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + + with ( + patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write, + patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize, + ): + bus.write(data_name, f"dummy_{id_}", value) + + mock__write.assert_called_once_with( + addr, + length, + id_, + value, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with(data_name, {id_: value}) + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Firmware_Version", 1, 14), + ("Model_Number", 1, 5678), + ("Present_Position", 2, 1337), + ("Present_Velocity", 3, 42), + ], +) +def test_sync_read_by_str(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = [id_] + expected_value = {f"dummy_{id_}": value} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name, f"dummy_{id_}") + + assert returned_dict == expected_value + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(data_name, {id_: value}) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Model_Number", {1: 5678}), + ("Present_Position", {1: 1337, 2: 42}), + ("Present_Velocity", {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test_sync_read_by_list(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = list(ids_values) + expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids]) + + assert returned_dict == expected_values + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(data_name, ids_values) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Model_Number", {1: 5678, 2: 5799, 3: 5678}), + ("Present_Position", {1: 1337, 2: 42, 3: 4016}), + ("Goal_Position", {1: 4008, 2: 199, 3: 3446}), + ], + ids=["Model_Number", "Present_Position", "Goal_Position"], +) +def test_sync_read_by_none(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = list(ids_values) + expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name) + + assert returned_dict == expected_values + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(data_name, ids_values) + + +@pytest.mark.parametrize( + "data_name, value", + [ + ("Goal_Position", 500), + ("Goal_Velocity", 4010), + ("Lock", 0), + ], +) +def test_sync_write_by_single_value(data_name, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids_values = {m.id: value for m in dummy_motors.values()} + + with ( + patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write, + patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize, + ): + bus.sync_write(data_name, value) + + mock__sync_write.assert_called_once_with( + addr, + length, + ids_values, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with(data_name, ids_values) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Goal_Position", {1: 1337, 2: 42, 3: 4016}), + ("Goal_Velocity", {1: 50, 2: 83, 3: 2777}), + ("Lock", {1: 0, 2: 0, 3: 1}), + ], + ids=["Goal_Position", "Goal_Velocity", "Lock"], +) +def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write, + patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize, + ): + bus.sync_write(data_name, values) + + mock__sync_write.assert_called_once_with( + addr, + length, + ids_values, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with(data_name, ids_values) From d70bc4bde940af3108581a115e0202a564db3bcb Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 14 Apr 2025 15:16:38 +0200 Subject: [PATCH 18/23] Add more segmented tests (dynamixel) --- .gitignore | 2 +- lerobot/common/motors/dynamixel/dynamixel.py | 4 +- tests/mocks/mock_dynamixel.py | 186 +++++++----- tests/motors/test_dynamixel.py | 290 +++++++++++-------- 4 files changed, 275 insertions(+), 207 deletions(-) diff --git a/.gitignore b/.gitignore index d6c51c90d..42f2e7552 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +.dev # Logging logs tmp diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index 1ebefac07..21f2524c4 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -143,11 +143,11 @@ class DynamixelMotorsBus(MotorsBus): self.write("Return_Delay_Time", id_, 0) def disable_torque(self, motors: str | list[str] | None = None) -> None: - for name in self._get_names_list(motors): + for name in self._get_motors_list(motors): self.write("Torque_Enable", name, TorqueMode.DISABLED.value) def enable_torque(self, motors: str | list[str] | None = None) -> None: - for name in self._get_names_list(motors): + for name in self._get_motors_list(motors): self.write("Torque_Enable", name, TorqueMode.ENABLED.value) def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 1c1ab6fec..a9d434e95 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -46,41 +46,6 @@ DXL_CRC_TABLE = [ 0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202 ] # fmt: skip -# https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction -INSTRUCTION_TYPES = { - "Ping": dxl.INST_PING, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID - "Read": dxl.INST_READ, # Read data from the Device - "Write": dxl.INST_WRITE, # Write data to the Device - "Reg_Write": dxl.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command - "Action": dxl.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write - "Factory_Reset": dxl.INST_FACTORY_RESET, # Resets the Control Table to its initial factory default settings - "Reboot": dxl.INST_REBOOT, # Reboot the Device - "Clear": dxl.INST_CLEAR, # Reset certain information stored in memory - "Control_Table_Backup": 0x20, # Store current Control Table status data to a Backup or to restore backup EEPROM data. - "Status": dxl.INST_STATUS, # Return packet sent following the execution of an Instruction Packet - "Sync_Read": dxl.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once - "Sync_Write": dxl.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once - "Fast_Sync_Read": 0x8A, # Read data from multiple devices with the same Address with the same length at once - "Bulk_Read": dxl.INST_BULK_READ, # Read data from multiple devices with different Addresses with different lengths at once - "Bulk_Write": dxl.INST_BULK_WRITE, # Write data to multiple devices with different Addresses with different lengths at once - "Fast_Bulk_Read": 0x9A, # Read data from multiple devices with different Addresses with different lengths at once -} # fmt: skip - -# https://emanual.robotis.com/docs/en/dxl/protocol2/#error -ERROR_TYPE = { - "Success": 0x00, # No error - "Result_Fail": dxl.ERRNUM_RESULT_FAIL, # Failed to process the sent Instruction Packet - "Instruction_Error": dxl.ERRNUM_INSTRUCTION, # An undefined Instruction has been usedAction has been used without Reg Write - "CRC_Error": dxl.ERRNUM_CRC, # The CRC of the sent Packet does not match the expected value - "Data_Range_Error": dxl.ERRNUM_DATA_RANGE, # Data to be written to the specified Address is outside the range of the minimum/maximum value - "Data_Length_Error": dxl.ERRNUM_DATA_LENGTH, # Attempted to write Data that is shorter than the required data length of the specified Address - # (ex: when you attempt to only use 2 bytes of a register that has been defined as 4 bytes) - "Data_Limit_Error": dxl.ERRNUM_DATA_LIMIT, # Data to be written to the specified Address is outside of the configured Limit value - "Access_Error": dxl.ERRNUM_ACCESS, # Attempted to write a value to an Address that is Read Only or has not been defined - # Attempted to read a value from an Address that is Write Only or has not been defined - # Attempted to write a value to an EEPROM register while Torque was Enabled. -} # fmt: skip - class MockDynamixelPacketv2(abc.ABC): @classmethod @@ -187,14 +152,14 @@ class MockInstructionPacket(MockDynamixelPacketv2): """ @classmethod - def _build(cls, dxl_id: int, params: list[int], length: int, instruct_type: str) -> list[int]: - instruct_value = INSTRUCTION_TYPES[instruct_type] + def _build(cls, dxl_id: int, params: list[int], length: int, instruction: int) -> list[int]: + length = len(params) + 3 return [ 0xFF, 0xFF, 0xFD, 0x00, # header dxl_id, # servo id dxl.DXL_LOBYTE(length), # length_l dxl.DXL_HIBYTE(length), # length_h - instruct_value, # instruction type + instruction, # instruction type *params, # data bytes 0x00, 0x00 # placeholder for CRC ] # fmt: skip @@ -210,8 +175,39 @@ class MockInstructionPacket(MockDynamixelPacketv2): No parameters required. """ - params, length = [], 3 - return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Ping") + return cls.build(dxl_id=dxl_id, params=[], length=3, instruction=dxl.INST_PING) + + @classmethod + def read( + cls, + dxl_id: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Read" instruction. + https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02 + + The parameters for Read (Protocol 2.0) are: + param[0] = start_address L + param[1] = start_address H + param[2] = data_length L + param[3] = data_length H + + And 'length' = data_length + 5, where: + +1 is for instruction byte, + +2 is for the length bytes, + +2 is for the CRC at the end. + """ + params = [ + dxl.DXL_LOBYTE(start_address), + dxl.DXL_HIBYTE(start_address), + dxl.DXL_LOBYTE(data_length), + dxl.DXL_HIBYTE(data_length), + ] + length = len(params) + 3 + # length = data_length + 5 + return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_READ) @classmethod def write( @@ -245,7 +241,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): *data, ] length = data_length + 5 - return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Write") + return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_WRITE) @classmethod def sync_read( @@ -279,7 +275,9 @@ class MockInstructionPacket(MockDynamixelPacketv2): *dxl_ids, ] length = len(dxl_ids) + 7 - return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read") + return cls.build( + dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_READ + ) @classmethod def sync_write( @@ -326,7 +324,9 @@ class MockInstructionPacket(MockDynamixelPacketv2): *data, ] length = len(ids_values) * (1 + data_length) + 7 - return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") + return cls.build( + dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_WRITE + ) class MockStatusPacket(MockDynamixelPacketv2): @@ -342,21 +342,20 @@ class MockStatusPacket(MockDynamixelPacketv2): """ @classmethod - def _build(cls, dxl_id: int, params: list[int], length: int, error: str = "Success") -> list[int]: - err_byte = ERROR_TYPE[error] + def _build(cls, dxl_id: int, params: list[int], length: int, error: int = 0) -> list[int]: return [ 0xFF, 0xFF, 0xFD, 0x00, # header dxl_id, # servo id dxl.DXL_LOBYTE(length), # length_l dxl.DXL_HIBYTE(length), # length_h 0x55, # instruction = 'status' - err_byte, # error + error, # error *params, # data bytes 0x00, 0x00 # placeholder for CRC ] # fmt: skip @classmethod - def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50) -> bytes: + def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50, error: int = 0) -> bytes: """ Builds a 'Ping' status packet. https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01 @@ -373,10 +372,10 @@ class MockStatusPacket(MockDynamixelPacketv2): """ params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver] length = 7 - return cls.build(dxl_id, params=params, length=length) + return cls.build(dxl_id, params=params, length=length, error=error) @classmethod - def read(cls, dxl_id: int, value: int, param_length: int) -> bytes: + def read(cls, dxl_id: int, value: int, param_length: int, error: int = 0) -> bytes: """ Builds a 'Read' status packet (also works for 'Sync Read') https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02 @@ -392,7 +391,7 @@ class MockStatusPacket(MockDynamixelPacketv2): """ params = _split_into_byte_chunks(value, param_length) length = param_length + 4 - return cls.build(dxl_id, params=params, length=length) + return cls.build(dxl_id, params=params, length=length, error=error) class MockPortHandler(dxl.PortHandler): @@ -456,10 +455,10 @@ class MockMotors(MockSerial): return stub_name def build_ping_stub( - self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0 + self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0, error: int = 0 ) -> str: ping_request = MockInstructionPacket.ping(dxl_id) - return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver) + return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver, error) ping_response = self._build_send_fn(return_packet, num_invalid_try) stub_name = f"Ping_{dxl_id}" self.stub( @@ -469,14 +468,63 @@ class MockMotors(MockSerial): ) return stub_name - def build_sync_read_stub( - self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + def build_read_stub( + self, + address: int, + length: int, + dxl_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, + ) -> str: + read_request = MockInstructionPacket.read(dxl_id, address, length) + return_packet = MockStatusPacket.read(dxl_id, value, length, error) if reply else b"" + read_response = self._build_send_fn(return_packet, num_invalid_try) + stub_name = f"Read_{address}_{length}_{dxl_id}_{value}_{error}" + self.stub( + name=stub_name, + receive_bytes=read_request, + send_fn=read_response, + ) + return stub_name + + def build_write_stub( + self, + address: int, + length: int, + dxl_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, + ) -> str: + sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length) + return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) if reply else b"" + stub_name = f"Write_{address}_{length}_{dxl_id}" + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(return_packet, num_invalid_try), + ) + return stub_name + + def build_sync_read_stub( + self, + address: int, + length: int, + ids_values: dict[int, int], + reply: bool = True, + num_invalid_try: int = 0, ) -> str: - address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) - return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) + return_packets = ( + b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) + if reply + else b"" + ) sync_read_response = self._build_send_fn(return_packets, num_invalid_try) - stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -485,11 +533,10 @@ class MockMotors(MockSerial): return stub_name def build_sequential_sync_read_stub( - self, data_name: str, ids_values: dict[int, list[int]] | None = None + self, address: int, length: int, ids_values: dict[int, list[int]] | None = None ) -> str: sequence_length = len(next(iter(ids_values.values()))) assert all(len(positions) == sequence_length for positions in ids_values.values()) - address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) sequential_packets = [] for count in range(sequence_length): @@ -499,7 +546,7 @@ class MockMotors(MockSerial): sequential_packets.append(return_packets) sync_read_response = self._build_sequential_send_fn(sequential_packets) - stub_name = f"Seq_Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -508,11 +555,10 @@ class MockMotors(MockSerial): return stub_name def build_sync_write_stub( - self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0 ) -> str: - address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) - stub_name = f"Sync_Write_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -520,20 +566,6 @@ class MockMotors(MockSerial): ) return stub_name - def build_write_stub( - self, data_name: str, dxl_id: int, value: int, error: str = "Success", num_invalid_try: int = 0 - ) -> str: - address, length = self.ctrl_table[data_name] - sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length) - return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) - stub_name = f"Write_{data_name}_{dxl_id}" - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=self._build_send_fn(return_packet, num_invalid_try), - ) - return stub_name - @staticmethod def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: def send_fn(_call_count: int) -> bytes: diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index 2b7088360..163af2d16 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -1,3 +1,4 @@ +import re import sys from typing import Generator from unittest.mock import MagicMock, patch @@ -7,6 +8,7 @@ import pytest from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode from lerobot.common.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus +from lerobot.common.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE from lerobot.common.utils.encoding_utils import encode_twos_complement from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler @@ -87,7 +89,7 @@ def test_abc_implementation(dummy_motors): @pytest.mark.parametrize("id_", [1, 2, 3]) def test_ping(id_, mock_motors, dummy_motors): expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] - stub_name = mock_motors.build_ping_stub(id_, expected_model_nb) + stub = mock_motors.build_ping_stub(id_, expected_model_nb) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -97,13 +99,13 @@ def test_ping(id_, mock_motors, dummy_motors): ping_model_nb = motors_bus.ping(id_) assert ping_model_nb == expected_model_nb - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called def test_broadcast_ping(mock_motors, dummy_motors): models = {m.id: m.model for m in dummy_motors.values()} expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()} - stub_name = mock_motors.build_broadcast_ping_stub(expected_model_nbs) + stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -113,178 +115,202 @@ def test_broadcast_ping(mock_motors, dummy_motors): ping_model_nbs = motors_bus.broadcast_ping() assert ping_model_nbs == expected_model_nbs - assert mock_motors.stubs[stub_name].called - - -def test_sync_read_none(mock_motors, dummy_motors): - expected_positions = { - "dummy_1": 1337, - "dummy_2": 42, - "dummy_3": 4016, - } - ids_values = dict(zip([1, 2, 3], expected_positions.values(), strict=True)) - stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - read_positions = motors_bus.sync_read("Present_Position", normalize=False) - - assert mock_motors.stubs[stub_name].called - assert read_positions == expected_positions + assert mock_motors.stubs[stub].called @pytest.mark.parametrize( - "id_, position", + "addr, length, id_, value", [ - (1, 1337), - (2, 42), - (3, 4016), + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), ], ) -def test_sync_read_single_value(id_, position, mock_motors, dummy_motors): - expected_position = {f"dummy_{id_}": position} - stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position}) +def test__read(addr, length, id_, value, mock_motors, dummy_motors): + stub = mock_motors.build_read_stub(addr, length, id_, value) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - read_position = motors_bus.sync_read("Present_Position", f"dummy_{id_}", normalize=False) + read_value, _, _ = motors_bus._read(addr, length, id_) - assert mock_motors.stubs[stub_name].called - assert read_position == expected_position + assert mock_motors.stubs[stub].called + assert read_value == value -@pytest.mark.parametrize( - "ids, positions", - [ - ([1], [1337]), - ([1, 2], [1337, 42]), - ([1, 2, 3], [1337, 42, 4016]), - ], - ids=["1 motor", "2 motors", "3 motors"], -) # fmt: skip -def test_sync_read(ids, positions, mock_motors, dummy_motors): - assert len(ids) == len(positions) - names = [f"dummy_{dxl_id}" for dxl_id in ids] - expected_positions = dict(zip(names, positions, strict=True)) - ids_values = dict(zip(ids, positions, strict=True)) - stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values) +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) + stub = mock_motors.build_read_stub(addr, length, id_, value, error=error) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - read_positions = motors_bus.sync_read("Present_Position", names, normalize=False) - - assert mock_motors.stubs[stub_name].called - assert read_positions == expected_positions - - -@pytest.mark.parametrize( - "num_retry, num_invalid_try, pos", - [ - (0, 2, 1337), - (2, 3, 42), - (3, 2, 4016), - (2, 1, 999), - ], -) -def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors): - expected_position = {"dummy_1": pos} - stub_name = mock_motors.build_sync_read_stub( - "Present_Position", {1: pos}, num_invalid_try=num_invalid_try - ) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - if num_retry >= num_invalid_try: - pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry) - assert pos_dict == expected_position + if raise_on_error: + with pytest.raises( + RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!") + ): + motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) else: - with pytest.raises(ConnectionError): - _ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry) + _, _, read_error = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_error == error - expected_calls = min(1 + num_retry, 1 + num_invalid_try) - assert mock_motors.stubs[stub_name].calls == expected_calls + assert mock_motors.stubs[stub].called -@pytest.mark.parametrize( - "data_name, value", - [ - ("Torque_Enable", 0), - ("Torque_Enable", 1), - ("Goal_Position", 1337), - ("Goal_Position", 42), - ], -) -def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors): - ids_values = {m.id: value for m in dummy_motors.values()} - stub_name = mock_motors.build_sync_write_stub(data_name, ids_values) +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - motors_bus.sync_write(data_name, value, normalize=False) + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + else: + _, read_comm, _ = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_comm == dxl.COMM_RX_TIMEOUT - assert mock_motors.stubs[stub_name].wait_called() + assert mock_motors.stubs[stub].called @pytest.mark.parametrize( - "ids, positions", + "addr, length, id_, value", [ - ([1], [1337]), - ([1, 2], [1337, 42]), - ([1, 2, 3], [1337, 42, 4016]), + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), + ], +) +def test__write(addr, length, id_, value, mock_motors, dummy_motors): + stub = mock_motors.build_write_stub(addr, length, id_, value) + motors_bus = DynamixelMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + comm, error = motors_bus._write(addr, length, id_, value) + + assert mock_motors.stubs[stub].called + assert comm == dxl.COMM_SUCCESS + assert error == 0 + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) + stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) + motors_bus = DynamixelMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises( + RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!") + ): + motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + _, write_error = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_error == error + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) + motors_bus = DynamixelMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + write_comm, _ = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_comm == dxl.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), ], ids=["1 motor", "2 motors", "3 motors"], -) # fmt: skip -def test_sync_write(ids, positions, mock_motors, dummy_motors): - assert len(ids) == len(positions) - ids_values = dict(zip(ids, positions, strict=True)) - stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values) +) +def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): + stub = mock_motors.build_sync_read_stub(addr, length, ids_values) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - write_values = {f"dummy_{id_}": pos for id_, pos in ids_values.items()} - motors_bus.sync_write("Goal_Position", write_values, normalize=False) + read_values, _ = motors_bus._sync_read(addr, length, list(ids_values)) - assert mock_motors.stubs[stub_name].wait_called() + assert mock_motors.stubs[stub].called + assert read_values == ids_values + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, ids_values = (10, 4, {1: 1337}) + stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) + motors_bus = DynamixelMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + else: + _, read_comm = motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + assert read_comm == dxl.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub].called @pytest.mark.parametrize( - "data_name, dxl_id, value", + "addr, length, ids_values", [ - ("Torque_Enable", 1, 0), - ("Torque_Enable", 1, 1), - ("Goal_Position", 2, 1337), - ("Goal_Position", 3, 42), + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), ], + ids=["1 motor", "2 motors", "3 motors"], ) -def test_write(data_name, dxl_id, value, mock_motors, dummy_motors): - stub_name = mock_motors.build_write_stub(data_name, dxl_id, value) +def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): + stub = mock_motors.build_sync_write_stub(addr, length, ids_values) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - motors_bus.write(data_name, f"dummy_{dxl_id}", value, normalize=False) + comm = motors_bus._sync_write(addr, length, ids_values) - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].wait_called() + assert comm == dxl.COMM_SUCCESS def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): @@ -292,10 +318,10 @@ def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()} mins = {m.id: m.range_min for m in dummy_calibration.values()} maxes = {m.id: m.range_max for m in dummy_calibration.values()} - drive_modes_stub = mock_motors.build_sync_read_stub("Drive_Mode", drive_modes) - offsets_stub = mock_motors.build_sync_read_stub("Homing_Offset", encoded_homings) - mins_stub = mock_motors.build_sync_read_stub("Min_Position_Limit", mins) - maxes_stub = mock_motors.build_sync_read_stub("Max_Position_Limit", maxes) + drive_modes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Drive_Mode"], drive_modes) + offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings) + mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins) + maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -317,9 +343,15 @@ def test_reset_calibration(mock_motors, dummy_motors): write_mins_stubs = [] write_maxes_stubs = [] for motor in dummy_motors.values(): - write_homing_stubs.append(mock_motors.build_write_stub("Homing_Offset", motor.id, 0)) - write_mins_stubs.append(mock_motors.build_write_stub("Min_Position_Limit", motor.id, 0)) - write_maxes_stubs.append(mock_motors.build_write_stub("Max_Position_Limit", motor.id, 4095)) + write_homing_stubs.append( + mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0) + ) + write_mins_stubs.append( + mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0) + ) + write_maxes_stubs.append( + mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) + ) motors_bus = DynamixelMotorsBus( port=mock_motors.port, @@ -349,11 +381,13 @@ def test_set_half_turn_homings(mock_motors, dummy_motors): 2: 2005, # 2047 - 42 3: -1625, # 2047 - 3672 } - read_pos_stub = mock_motors.build_sync_read_stub("Present_Position", current_positions) + read_pos_stub = mock_motors.build_sync_read_stub( + *X_SERIES_CONTROL_TABLE["Present_Position"], current_positions + ) write_homing_stubs = [] for id_, homing in expected_homings.items(): encoded_homing = encode_twos_complement(homing, 4) - stub = mock_motors.build_write_stub("Homing_Offset", id_, encoded_homing) + stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing) write_homing_stubs.append(stub) motors_bus = DynamixelMotorsBus( @@ -386,7 +420,9 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors): "dummy_2": 3600, "dummy_3": 4002, } - read_pos_stub = mock_motors.build_sequential_sync_read_stub("Present_Position", positions) + read_pos_stub = mock_motors.build_sequential_sync_read_stub( + *X_SERIES_CONTROL_TABLE["Present_Position"], positions + ) with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): motors_bus = DynamixelMotorsBus( port=mock_motors.port, From 1f210bc8a37c15c213878690e64e59783ee93999 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 14 Apr 2025 15:26:29 +0200 Subject: [PATCH 19/23] Refactor tests --- tests/mocks/mock_dynamixel.py | 3 - tests/mocks/mock_feetech.py | 3 - tests/motors/test_dynamixel.py | 146 +++++++++----------------- tests/motors/test_feetech.py | 185 +++++++++++++-------------------- 4 files changed, 124 insertions(+), 213 deletions(-) diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index a9d434e95..6f78400d7 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -5,7 +5,6 @@ import dynamixel_sdk as dxl import serial from mock_serial.mock_serial import MockSerial -from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE from lerobot.common.motors.dynamixel.dynamixel import _split_into_byte_chunks from .mock_serial_patch import WaitableStub @@ -425,8 +424,6 @@ class MockMotors(MockSerial): instruction packets. It is meant to test MotorsBus classes. """ - ctrl_table = X_SERIES_CONTROL_TABLE - def __init__(self): super().__init__() diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index f4bb1c686..5948bd5eb 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -5,7 +5,6 @@ import scservo_sdk as scs import serial from mock_serial import MockSerial -from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout from .mock_serial_patch import WaitableStub @@ -278,8 +277,6 @@ class MockMotors(MockSerial): instruction packets. It is meant to test MotorsBus classes. """ - ctrl_table = STS_SMS_SERIES_CONTROL_TABLE - def __init__(self): super().__init__() diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index 163af2d16..cb2c11e69 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -90,13 +90,10 @@ def test_abc_implementation(dummy_motors): def test_ping(id_, mock_motors, dummy_motors): expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] stub = mock_motors.build_ping_stub(id_, expected_model_nb) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - ping_model_nb = motors_bus.ping(id_) + ping_model_nb = bus.ping(id_) assert ping_model_nb == expected_model_nb assert mock_motors.stubs[stub].called @@ -106,13 +103,10 @@ def test_broadcast_ping(mock_motors, dummy_motors): models = {m.id: m.model for m in dummy_motors.values()} expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()} stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - ping_model_nbs = motors_bus.broadcast_ping() + ping_model_nbs = bus.broadcast_ping() assert ping_model_nbs == expected_model_nbs assert mock_motors.stubs[stub].called @@ -128,13 +122,10 @@ def test_broadcast_ping(mock_motors, dummy_motors): ) def test__read(addr, length, id_, value, mock_motors, dummy_motors): stub = mock_motors.build_read_stub(addr, length, id_, value) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - read_value, _, _ = motors_bus._read(addr, length, id_) + read_value, _, _ = bus._read(addr, length, id_) assert mock_motors.stubs[stub].called assert read_value == value @@ -144,19 +135,16 @@ def test__read(addr, length, id_, value, mock_motors, dummy_motors): def test__read_error(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) stub = mock_motors.build_read_stub(addr, length, id_, value, error=error) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises( RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!") ): - motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + bus._read(addr, length, id_, raise_on_error=raise_on_error) else: - _, _, read_error = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + _, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error) assert read_error == error assert mock_motors.stubs[stub].called @@ -166,17 +154,14 @@ def test__read_error(raise_on_error, mock_motors, dummy_motors): def test__read_comm(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value = (10, 4, 1, 1337) stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + bus._read(addr, length, id_, raise_on_error=raise_on_error) else: - _, read_comm, _ = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + _, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error) assert read_comm == dxl.COMM_RX_TIMEOUT assert mock_motors.stubs[stub].called @@ -192,13 +177,10 @@ def test__read_comm(raise_on_error, mock_motors, dummy_motors): ) def test__write(addr, length, id_, value, mock_motors, dummy_motors): stub = mock_motors.build_write_stub(addr, length, id_, value) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - comm, error = motors_bus._write(addr, length, id_, value) + comm, error = bus._write(addr, length, id_, value) assert mock_motors.stubs[stub].called assert comm == dxl.COMM_SUCCESS @@ -209,19 +191,16 @@ def test__write(addr, length, id_, value, mock_motors, dummy_motors): def test__write_error(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises( RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!") ): - motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + bus._write(addr, length, id_, value, raise_on_error=raise_on_error) else: - _, write_error = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + _, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) assert write_error == error assert mock_motors.stubs[stub].called @@ -231,17 +210,14 @@ def test__write_error(raise_on_error, mock_motors, dummy_motors): def test__write_comm(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value = (10, 4, 1, 1337) stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + bus._write(addr, length, id_, value, raise_on_error=raise_on_error) else: - write_comm, _ = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) assert write_comm == dxl.COMM_RX_TIMEOUT assert mock_motors.stubs[stub].called @@ -258,13 +234,10 @@ def test__write_comm(raise_on_error, mock_motors, dummy_motors): ) def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): stub = mock_motors.build_sync_read_stub(addr, length, ids_values) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - read_values, _ = motors_bus._sync_read(addr, length, list(ids_values)) + read_values, _ = bus._sync_read(addr, length, list(ids_values)) assert mock_motors.stubs[stub].called assert read_values == ids_values @@ -274,17 +247,14 @@ def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): addr, length, ids_values = (10, 4, {1: 1337}) stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) else: - _, read_comm = motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + _, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) assert read_comm == dxl.COMM_RX_TIMEOUT assert mock_motors.stubs[stub].called @@ -301,13 +271,10 @@ def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): ) def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): stub = mock_motors.build_sync_write_stub(addr, length, ids_values) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - comm = motors_bus._sync_write(addr, length, ids_values) + comm = bus._sync_write(addr, length, ids_values) assert mock_motors.stubs[stub].wait_called() assert comm == dxl.COMM_SUCCESS @@ -322,14 +289,14 @@ def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings) mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins) maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes) - motors_bus = DynamixelMotorsBus( + bus = DynamixelMotorsBus( port=mock_motors.port, motors=dummy_motors, calibration=dummy_calibration, ) - motors_bus.connect(assert_motors_exist=False) + bus.connect(assert_motors_exist=False) - is_calibrated = motors_bus.is_calibrated + is_calibrated = bus.is_calibrated assert is_calibrated assert mock_motors.stubs[drive_modes_stub].called @@ -353,13 +320,10 @@ def test_reset_calibration(mock_motors, dummy_motors): mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) ) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - motors_bus.reset_calibration() + bus.reset_calibration() assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs) @@ -390,16 +354,13 @@ def test_set_half_turn_homings(mock_motors, dummy_motors): stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing) write_homing_stubs.append(stub) - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - motors_bus.reset_calibration = MagicMock() + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) + bus.reset_calibration = MagicMock() - motors_bus.set_half_turn_homings() + bus.set_half_turn_homings() - motors_bus.reset_calibration.assert_called_once() + bus.reset_calibration.assert_called_once() assert mock_motors.stubs[read_pos_stub].called assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) @@ -424,13 +385,10 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors): *X_SERIES_CONTROL_TABLE["Present_Position"], positions ) with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): - motors_bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - mins, maxes = motors_bus.record_ranges_of_motion(display_values=False) + mins, maxes = bus.record_ranges_of_motion(display_values=False) assert mock_motors.stubs[read_pos_stub].calls == 3 assert mins == expected_mins diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index d25b98bc6..baf6d3407 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -91,36 +91,19 @@ def test_abc_implementation(dummy_motors): FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors) -@pytest.mark.skip("TODO") -def test_scan_port(mock_motors): - expected = { - 9_600: {1: 777}, - 57_600: {2: 777}, - 500_000: {237: 777}, - } - expected_model_nbs = {id_: model for d in expected.values() for id_, model in d.items()} - ping_stub = mock_motors.build_broadcast_ping_stub(list(expected_model_nbs)) - mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs) - found = FeetechMotorsBus.scan_port(mock_motors.port) - - assert found == expected - assert mock_motors.stubs[ping_stub].called - assert mock_motors.stubs[mobel_nb_stub].called - - @pytest.mark.parametrize("id_", [1, 2, 3]) def test_ping(id_, mock_motors, dummy_motors): expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] addr, length = MODEL_NUMBER ping_stub = mock_motors.build_ping_stub(id_) mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb) - motors_bus = FeetechMotorsBus( + bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) - motors_bus.connect(assert_motors_exist=False) + bus.connect(assert_motors_exist=False) - ping_model_nb = motors_bus.ping(id_) + ping_model_nb = bus.ping(id_) assert ping_model_nb == expected_model_nb assert mock_motors.stubs[ping_stub].called @@ -138,13 +121,13 @@ def test_broadcast_ping(mock_motors, dummy_motors): stub = mock_motors.build_read_stub(addr, length, id_, model_nb) expected_model_nbs[id_] = model_nb mobel_nb_stubs.append(stub) - motors_bus = FeetechMotorsBus( + bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) - motors_bus.connect(assert_motors_exist=False) + bus.connect(assert_motors_exist=False) - ping_model_nbs = motors_bus.broadcast_ping() + ping_model_nbs = bus.broadcast_ping() assert ping_model_nbs == expected_model_nbs assert mock_motors.stubs[ping_stub].called @@ -160,57 +143,57 @@ def test_broadcast_ping(mock_motors, dummy_motors): ], ) def test__read(addr, length, id_, value, mock_motors, dummy_motors): - stub_name = mock_motors.build_read_stub(addr, length, id_, value) - motors_bus = FeetechMotorsBus( + stub = mock_motors.build_read_stub(addr, length, id_, value) + bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) - motors_bus.connect(assert_motors_exist=False) + bus.connect(assert_motors_exist=False) - read_value, _, _ = motors_bus._read(addr, length, id_) + read_value, _, _ = bus._read(addr, length, id_) - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called assert read_value == value @pytest.mark.parametrize("raise_on_error", (True, False)) def test__read_error(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) - stub_name = mock_motors.build_read_stub(addr, length, id_, value, error=error) - motors_bus = FeetechMotorsBus( + stub = mock_motors.build_read_stub(addr, length, id_, value, error=error) + bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) - motors_bus.connect(assert_motors_exist=False) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): - motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + bus._read(addr, length, id_, raise_on_error=raise_on_error) else: - _, _, read_error = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + _, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error) assert read_error == error - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called @pytest.mark.parametrize("raise_on_error", (True, False)) def test__read_comm(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value = (10, 4, 1, 1337) - stub_name = mock_motors.build_read_stub(addr, length, id_, value, reply=False) - motors_bus = FeetechMotorsBus( + stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False) + bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) - motors_bus.connect(assert_motors_exist=False) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + bus._read(addr, length, id_, raise_on_error=raise_on_error) else: - _, read_comm, _ = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + _, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error) assert read_comm == scs.COMM_RX_TIMEOUT - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called @pytest.mark.parametrize( @@ -222,16 +205,16 @@ def test__read_comm(raise_on_error, mock_motors, dummy_motors): ], ) def test__write(addr, length, id_, value, mock_motors, dummy_motors): - stub_name = mock_motors.build_write_stub(addr, length, id_, value) - motors_bus = FeetechMotorsBus( + stub = mock_motors.build_write_stub(addr, length, id_, value) + bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) - motors_bus.connect(assert_motors_exist=False) + bus.connect(assert_motors_exist=False) - comm, error = motors_bus._write(addr, length, id_, value) + comm, error = bus._write(addr, length, id_, value) - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called assert comm == scs.COMM_SUCCESS assert error == 0 @@ -239,41 +222,35 @@ def test__write(addr, length, id_, value, mock_motors, dummy_motors): @pytest.mark.parametrize("raise_on_error", (True, False)) def test__write_error(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) - stub_name = mock_motors.build_write_stub(addr, length, id_, value, error=error) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): - motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + bus._write(addr, length, id_, value, raise_on_error=raise_on_error) else: - _, write_error = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + _, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) assert write_error == error - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called @pytest.mark.parametrize("raise_on_error", (True, False)) def test__write_comm(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value = (10, 4, 1, 1337) - stub_name = mock_motors.build_write_stub(addr, length, id_, value, reply=False) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + bus._write(addr, length, id_, value, raise_on_error=raise_on_error) else: - write_comm, _ = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) assert write_comm == scs.COMM_RX_TIMEOUT - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called @pytest.mark.parametrize( @@ -286,37 +263,31 @@ def test__write_comm(raise_on_error, mock_motors, dummy_motors): ids=["1 motor", "2 motors", "3 motors"], ) def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): - stub_name = mock_motors.build_sync_read_stub(addr, length, ids_values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + stub = mock_motors.build_sync_read_stub(addr, length, ids_values) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - read_values, _ = motors_bus._sync_read(addr, length, list(ids_values)) + read_values, _ = bus._sync_read(addr, length, list(ids_values)) - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called assert read_values == ids_values @pytest.mark.parametrize("raise_on_error", (True, False)) def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): addr, length, ids_values = (10, 4, {1: 1337}) - stub_name = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) else: - _, read_comm = motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + _, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) assert read_comm == scs.COMM_RX_TIMEOUT - assert mock_motors.stubs[stub_name].called + assert mock_motors.stubs[stub].called @pytest.mark.parametrize( @@ -329,16 +300,13 @@ def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): ids=["1 motor", "2 motors", "3 motors"], ) def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): - stub_name = mock_motors.build_sync_write_stub(addr, length, ids_values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + stub = mock_motors.build_sync_write_stub(addr, length, ids_values) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - comm = motors_bus._sync_write(addr, length, ids_values) + comm = bus._sync_write(addr, length, ids_values) - assert mock_motors.stubs[stub_name].wait_called() + assert mock_motors.stubs[stub].wait_called() assert comm == scs.COMM_SUCCESS @@ -351,14 +319,14 @@ def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): ) mins_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins) maxes_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes) - motors_bus = FeetechMotorsBus( + bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, calibration=dummy_calibration, ) - motors_bus.connect(assert_motors_exist=False) + bus.connect(assert_motors_exist=False) - is_calibrated = motors_bus.is_calibrated + is_calibrated = bus.is_calibrated assert is_calibrated assert mock_motors.stubs[offsets_stub].called @@ -381,13 +349,10 @@ def test_reset_calibration(mock_motors, dummy_motors): mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) ) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - motors_bus.reset_calibration() + bus.reset_calibration() assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs) @@ -420,16 +385,13 @@ def test_set_half_turn_homings(mock_motors, dummy_motors): ) write_homing_stubs.append(stub) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - motors_bus.reset_calibration = MagicMock() + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) + bus.reset_calibration = MagicMock() - motors_bus.set_half_turn_homings() + bus.set_half_turn_homings() - motors_bus.reset_calibration.assert_called_once() + bus.reset_calibration.assert_called_once() assert mock_motors.stubs[read_pos_stub].called assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) @@ -450,18 +412,15 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors): "dummy_2": 3600, "dummy_3": 4002, } - read_pos_stub = mock_motors.build_sequential_sync_read_stub( + stub = mock_motors.build_sequential_sync_read_stub( *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions ) with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(assert_motors_exist=False) - mins, maxes = motors_bus.record_ranges_of_motion(display_values=False) + mins, maxes = bus.record_ranges_of_motion(display_values=False) - assert mock_motors.stubs[read_pos_stub].calls == 3 + assert mock_motors.stubs[stub].calls == 3 assert mins == expected_mins assert maxes == expected_maxes From 889de7c415afeb2e44df6b8078c8be25df8971a9 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 14 Apr 2025 17:14:06 +0200 Subject: [PATCH 20/23] Add handshake, fix feetech _read_firmware_version --- lerobot/common/motors/dynamixel/dynamixel.py | 3 ++ lerobot/common/motors/feetech/feetech.py | 21 +++++++++++--- lerobot/common/motors/motors_bus.py | 20 +++++++++---- tests/motors/test_dynamixel.py | 30 ++++++++++---------- tests/motors/test_feetech.py | 30 ++++++++++---------- tests/motors/test_motors_bus.py | 14 ++++----- 6 files changed, 71 insertions(+), 47 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index 21f2524c4..4bcbf6b03 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -136,6 +136,9 @@ class DynamixelMotorsBus(MotorsBus): def _assert_protocol_is_compatible(self, instruction_name: str) -> None: pass + def _handshake(self) -> None: + self._assert_motors_exist() + def configure_motors(self) -> None: # By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on # the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0). diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 193f1b4a0..8c31401b4 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -22,6 +22,7 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value from .tables import ( FIRMWARE_MAJOR_VERSION, + FIRMWARE_MINOR_VERSION, MODEL_BAUDRATE_TABLE, MODEL_CONTROL_TABLE, MODEL_ENCODING_TABLE, @@ -150,6 +151,18 @@ class FeetechMotorsBus(MotorsBus): "'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead." ) + def _assert_same_firmware(self) -> None: + firmware_versions = self._read_firmware_version(self.ids) + if len(set(firmware_versions.values())) != 1: + raise RuntimeError( + "Some Motors use different firmware versions. Update their firmware first using Feetech's software. " + "Visit https://www.feetechrc.com/software." + ) + + def _handshake(self) -> None: + self._assert_motors_exist() + self._assert_same_firmware() + def configure_motors(self) -> None: # By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on the # 'Return_Delay' address). We ensure this is reduced to the minimum of 2µs (value of 0). @@ -317,9 +330,9 @@ class FeetechMotorsBus(MotorsBus): display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") - return self._get_model_number(list(ids_status), raise_on_error) + return self._read_model_number(list(ids_status), raise_on_error) - def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]: + def _read_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]: firmware_versions = {} for id_ in motor_ids: firm_ver_major, comm, error = self._read( @@ -329,7 +342,7 @@ class FeetechMotorsBus(MotorsBus): return firm_ver_minor, comm, error = self._read( - *FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error + *FIRMWARE_MINOR_VERSION, id_, raise_on_error=raise_on_error ) if not self._is_comm_success(comm) or self._is_error(error): return @@ -338,7 +351,7 @@ class FeetechMotorsBus(MotorsBus): return firmware_versions - def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: + def _read_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: model_numbers = {} for id_ in motor_ids: model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error) diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index d0f8ff3ed..1ec4a201d 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -377,9 +377,13 @@ class MotorsBus(abc.ABC): def _assert_motors_exist(self) -> None: # TODO(aliberts): collect all wrong ids/models and display them at once - found_models = self.broadcast_ping() + found_models = {} + for id_ in self.ids: + model_nb = self.ping(id_) + if model_nb is not None: + found_models[id_] = model_nb expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()} - if not found_models or set(found_models) != set(self.ids): + if set(found_models) != set(self.ids): raise RuntimeError( f"{self.__class__.__name__} is supposed to have these motors: ({{id: model_nb}})" f"\n{pformat(expected_models, indent=4, sort_dicts=False)}\n" @@ -403,7 +407,7 @@ class MotorsBus(abc.ABC): def is_connected(self) -> bool: return self.port_handler.is_open - def connect(self, assert_motors_exist: bool = True) -> None: + def connect(self, handshake: bool = True) -> None: if self.is_connected: raise DeviceAlreadyConnectedError( f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice." @@ -412,8 +416,8 @@ class MotorsBus(abc.ABC): try: if not self.port_handler.openPort(): raise OSError(f"Failed to open port '{self.port}'.") - elif assert_motors_exist: - self._assert_motors_exist() + elif handshake: + self._handshake() except (FileNotFoundError, OSError, serial.SerialException) as e: raise ConnectionError( f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port." @@ -423,6 +427,10 @@ class MotorsBus(abc.ABC): self.set_timeout() logger.debug(f"{self.__class__.__name__} connected.") + @abc.abstractmethod + def _handshake(self) -> None: + pass + @classmethod def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]: bus = cls(port, {}, *args, **kwargs) @@ -690,7 +698,7 @@ class MotorsBus(abc.ABC): return if self._is_error(error): if raise_on_error: - raise RuntimeError(self.packet_handler.getTxRxResult(comm)) + raise RuntimeError(self.packet_handler.getRxPacketError(error)) else: return diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index cb2c11e69..822fd0493 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -91,7 +91,7 @@ def test_ping(id_, mock_motors, dummy_motors): expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] stub = mock_motors.build_ping_stub(id_, expected_model_nb) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) ping_model_nb = bus.ping(id_) @@ -104,7 +104,7 @@ def test_broadcast_ping(mock_motors, dummy_motors): expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()} stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) ping_model_nbs = bus.broadcast_ping() @@ -123,7 +123,7 @@ def test_broadcast_ping(mock_motors, dummy_motors): def test__read(addr, length, id_, value, mock_motors, dummy_motors): stub = mock_motors.build_read_stub(addr, length, id_, value) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) read_value, _, _ = bus._read(addr, length, id_) @@ -136,7 +136,7 @@ def test__read_error(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) stub = mock_motors.build_read_stub(addr, length, id_, value, error=error) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises( @@ -155,7 +155,7 @@ def test__read_comm(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value = (10, 4, 1, 1337) stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): @@ -178,7 +178,7 @@ def test__read_comm(raise_on_error, mock_motors, dummy_motors): def test__write(addr, length, id_, value, mock_motors, dummy_motors): stub = mock_motors.build_write_stub(addr, length, id_, value) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) comm, error = bus._write(addr, length, id_, value) @@ -192,7 +192,7 @@ def test__write_error(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises( @@ -211,7 +211,7 @@ def test__write_comm(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value = (10, 4, 1, 1337) stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): @@ -235,7 +235,7 @@ def test__write_comm(raise_on_error, mock_motors, dummy_motors): def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): stub = mock_motors.build_sync_read_stub(addr, length, ids_values) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) read_values, _ = bus._sync_read(addr, length, list(ids_values)) @@ -248,7 +248,7 @@ def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): addr, length, ids_values = (10, 4, {1: 1337}) stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): @@ -272,7 +272,7 @@ def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): stub = mock_motors.build_sync_write_stub(addr, length, ids_values) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) comm = bus._sync_write(addr, length, ids_values) @@ -294,7 +294,7 @@ def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): motors=dummy_motors, calibration=dummy_calibration, ) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) is_calibrated = bus.is_calibrated @@ -321,7 +321,7 @@ def test_reset_calibration(mock_motors, dummy_motors): ) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) bus.reset_calibration() @@ -355,7 +355,7 @@ def test_set_half_turn_homings(mock_motors, dummy_motors): write_homing_stubs.append(stub) bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) bus.reset_calibration = MagicMock() bus.set_half_turn_homings() @@ -386,7 +386,7 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors): ) with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) mins, maxes = bus.record_ranges_of_motion(display_values=False) diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index baf6d3407..360c13cbd 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -101,7 +101,7 @@ def test_ping(id_, mock_motors, dummy_motors): port=mock_motors.port, motors=dummy_motors, ) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) ping_model_nb = bus.ping(id_) @@ -125,7 +125,7 @@ def test_broadcast_ping(mock_motors, dummy_motors): port=mock_motors.port, motors=dummy_motors, ) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) ping_model_nbs = bus.broadcast_ping() @@ -148,7 +148,7 @@ def test__read(addr, length, id_, value, mock_motors, dummy_motors): port=mock_motors.port, motors=dummy_motors, ) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) read_value, _, _ = bus._read(addr, length, id_) @@ -164,7 +164,7 @@ def test__read_error(raise_on_error, mock_motors, dummy_motors): port=mock_motors.port, motors=dummy_motors, ) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): @@ -184,7 +184,7 @@ def test__read_comm(raise_on_error, mock_motors, dummy_motors): port=mock_motors.port, motors=dummy_motors, ) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): @@ -210,7 +210,7 @@ def test__write(addr, length, id_, value, mock_motors, dummy_motors): port=mock_motors.port, motors=dummy_motors, ) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) comm, error = bus._write(addr, length, id_, value) @@ -224,7 +224,7 @@ def test__write_error(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): @@ -241,7 +241,7 @@ def test__write_comm(raise_on_error, mock_motors, dummy_motors): addr, length, id_, value = (10, 4, 1, 1337) stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): @@ -265,7 +265,7 @@ def test__write_comm(raise_on_error, mock_motors, dummy_motors): def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): stub = mock_motors.build_sync_read_stub(addr, length, ids_values) bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) read_values, _ = bus._sync_read(addr, length, list(ids_values)) @@ -278,7 +278,7 @@ def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): addr, length, ids_values = (10, 4, {1: 1337}) stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) if raise_on_error: with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): @@ -302,7 +302,7 @@ def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): stub = mock_motors.build_sync_write_stub(addr, length, ids_values) bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) comm = bus._sync_write(addr, length, ids_values) @@ -324,7 +324,7 @@ def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): motors=dummy_motors, calibration=dummy_calibration, ) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) is_calibrated = bus.is_calibrated @@ -350,7 +350,7 @@ def test_reset_calibration(mock_motors, dummy_motors): ) bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) bus.reset_calibration() @@ -386,7 +386,7 @@ def test_set_half_turn_homings(mock_motors, dummy_motors): write_homing_stubs.append(stub) bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) bus.reset_calibration = MagicMock() bus.set_half_turn_homings() @@ -417,7 +417,7 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors): ) with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) mins, maxes = bus.record_ranges_of_motion(display_values=False) diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index c98cda7dd..f3af8daf8 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -227,7 +227,7 @@ def test__serialize_data_large_number(): ) def test_read(data_name, id_, value, dummy_motors): bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) addr, length = DUMMY_CTRL_TABLE_2[data_name] with ( @@ -261,7 +261,7 @@ def test_read(data_name, id_, value, dummy_motors): ) def test_write(data_name, id_, value, dummy_motors): bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) addr, length = DUMMY_CTRL_TABLE_2[data_name] with ( @@ -296,7 +296,7 @@ def test_write(data_name, id_, value, dummy_motors): ) def test_sync_read_by_str(data_name, id_, value, dummy_motors): bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) addr, length = DUMMY_CTRL_TABLE_2[data_name] ids = [id_] expected_value = {f"dummy_{id_}": value} @@ -333,7 +333,7 @@ def test_sync_read_by_str(data_name, id_, value, dummy_motors): ) def test_sync_read_by_list(data_name, ids_values, dummy_motors): bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) addr, length = DUMMY_CTRL_TABLE_2[data_name] ids = list(ids_values) expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} @@ -370,7 +370,7 @@ def test_sync_read_by_list(data_name, ids_values, dummy_motors): ) def test_sync_read_by_none(data_name, ids_values, dummy_motors): bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) addr, length = DUMMY_CTRL_TABLE_2[data_name] ids = list(ids_values) expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} @@ -406,7 +406,7 @@ def test_sync_read_by_none(data_name, ids_values, dummy_motors): ) def test_sync_write_by_single_value(data_name, value, dummy_motors): bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) addr, length = DUMMY_CTRL_TABLE_2[data_name] ids_values = {m.id: value for m in dummy_motors.values()} @@ -441,7 +441,7 @@ def test_sync_write_by_single_value(data_name, value, dummy_motors): ) def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors): bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(assert_motors_exist=False) + bus.connect(handshake=False) addr, length = DUMMY_CTRL_TABLE_2[data_name] values = {f"dummy_{id_}": val for id_, val in ids_values.items()} From f71e224023ed30412981a3fc4649fef4b0082e61 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 15 Apr 2025 11:18:44 +0200 Subject: [PATCH 21/23] Fix tests --- tests/motors/test_motors_bus.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index f3af8daf8..879a8c81b 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -131,6 +131,7 @@ class MockMotorsBus(MotorsBus): self.port_handler = MockPortHandler(port) def _assert_protocol_is_compatible(self, instruction_name): ... + def _handshake(self): ... def configure_motors(self): ... def disable_torque(self, motors): ... def enable_torque(self, motors): ... From 9afc4b771c62b80be2cf004a047b09add1e14cc5 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 15 Apr 2025 11:20:42 +0200 Subject: [PATCH 22/23] Motors config & disconnect fixes --- lerobot/common/motors/dynamixel/dynamixel.py | 12 +++++----- lerobot/common/motors/feetech/feetech.py | 24 +++++++++++-------- lerobot/common/motors/motors_bus.py | 6 ++--- lerobot/common/robots/so100/so100_follower.py | 5 +--- .../teleoperators/so100/so100_leader.py | 1 + 5 files changed, 25 insertions(+), 23 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index 4bcbf6b03..52a84e5ed 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -142,16 +142,16 @@ class DynamixelMotorsBus(MotorsBus): def configure_motors(self) -> None: # By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on # the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0). - for id_ in self.ids: - self.write("Return_Delay_Time", id_, 0) + for motor in self.motors: + self.write("Return_Delay_Time", motor, 0) - def disable_torque(self, motors: str | list[str] | None = None) -> None: + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for name in self._get_motors_list(motors): - self.write("Torque_Enable", name, TorqueMode.DISABLED.value) + self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None) -> None: + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for name in self._get_motors_list(motors): - self.write("Torque_Enable", name, TorqueMode.ENABLED.value) + self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry) def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: for id_ in ids_values: diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 8c31401b4..bcf549724 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -164,10 +164,14 @@ class FeetechMotorsBus(MotorsBus): self._assert_same_firmware() def configure_motors(self) -> None: - # By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on the - # 'Return_Delay' address). We ensure this is reduced to the minimum of 2µs (value of 0). - for id_ in self.ids: - self.write("Return_Delay_Time", id_, 0) + for motor in self.motors: + # By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on + # the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0). + self.write("Return_Delay_Time", motor, 0) + # Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors. + # Note: this address is not in the official STS3215 Memory Table + self.write("Maximum_Acceleration", motor, 254) + self.write("Acceleration", motor, 254) def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]: """ @@ -182,15 +186,15 @@ class FeetechMotorsBus(MotorsBus): return half_turn_homings - def disable_torque(self, motors: str | list[str] | None = None) -> None: + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for name in self._get_motors_list(motors): - self.write("Torque_Enable", name, TorqueMode.DISABLED.value) - self.write("Lock", name, 0) + self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry) + self.write("Lock", name, 0, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None) -> None: + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for name in self._get_motors_list(motors): - self.write("Torque_Enable", name, TorqueMode.ENABLED.value) - self.write("Lock", name, 1) + self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry) + self.write("Lock", name, 1, num_retry=num_retry) def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: for id_ in ids_values: diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 1ec4a201d..71016fb50 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -456,11 +456,11 @@ class MotorsBus(abc.ABC): pass @abc.abstractmethod - def disable_torque(self, motors: str | list[str] | None = None) -> None: + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: pass @abc.abstractmethod - def enable_torque(self, motors: str | list[str] | None = None) -> None: + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: pass def set_timeout(self, timeout_ms: int | None = None): @@ -972,7 +972,7 @@ class MotorsBus(abc.ABC): if disable_torque: self.port_handler.clearPort() self.port_handler.is_using = False - self.disable_torque() + self.disable_torque(num_retry=5) self.port_handler.closePort() logger.debug(f"{self.__class__.__name__} disconnected.") diff --git a/lerobot/common/robots/so100/so100_follower.py b/lerobot/common/robots/so100/so100_follower.py index 13c5739bc..419b40470 100644 --- a/lerobot/common/robots/so100/so100_follower.py +++ b/lerobot/common/robots/so100/so100_follower.py @@ -55,6 +55,7 @@ class SO100Follower(Robot): "wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100), "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), }, + calibration=self.calibration, ) self.cameras = make_cameras_from_configs(config.cameras) @@ -152,10 +153,6 @@ class SO100Follower(Robot): # Set I_Coefficient and D_Coefficient to default value 0 and 32 self.arm.write("I_Coefficient", name, 0) self.arm.write("D_Coefficient", name, 32) - # Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of - # the motors. Note: this address is not in the official STS3215 Memory Table - self.arm.write("Maximum_Acceleration", name, 254) - self.arm.write("Acceleration", name, 254) self.arm.enable_torque() diff --git a/lerobot/common/teleoperators/so100/so100_leader.py b/lerobot/common/teleoperators/so100/so100_leader.py index f8f7239e8..0ed5eafc8 100644 --- a/lerobot/common/teleoperators/so100/so100_leader.py +++ b/lerobot/common/teleoperators/so100/so100_leader.py @@ -51,6 +51,7 @@ class SO100Leader(Teleoperator): "wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100), "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), }, + calibration=self.calibration, ) @property From 2bb73ac431f6ee20d236bff1cc07397faa86e64e Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 15 Apr 2025 11:43:22 +0200 Subject: [PATCH 23/23] Add torque_disabled context --- lerobot/common/motors/motors_bus.py | 9 ++++ lerobot/common/robots/koch/koch_follower.py | 41 +++++++++--------- lerobot/common/robots/so100/so100_follower.py | 20 ++++----- lerobot/common/robots/viperx/viperx.py | 43 +++++++++---------- 4 files changed, 59 insertions(+), 54 deletions(-) diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 71016fb50..b70a728c8 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -21,6 +21,7 @@ import abc import logging +from contextlib import contextmanager from dataclasses import dataclass from enum import Enum from functools import cached_property @@ -463,6 +464,14 @@ class MotorsBus(abc.ABC): def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: pass + @contextmanager + def torque_disabled(self): + self.disable_torque() + try: + yield + finally: + self.enable_torque() + def set_timeout(self, timeout_ms: int | None = None): timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout self.port_handler.setPacketTimeoutMillis(timeout_ms) diff --git a/lerobot/common/robots/koch/koch_follower.py b/lerobot/common/robots/koch/koch_follower.py index fc94f0ea9..2395118db 100644 --- a/lerobot/common/robots/koch/koch_follower.py +++ b/lerobot/common/robots/koch/koch_follower.py @@ -146,29 +146,28 @@ class KochFollower(Robot): logger.info(f"Calibration saved to {self.calibration_fpath}") def configure(self) -> None: - self.arm.disable_torque() - self.arm.configure_motors() - # Use 'extended position mode' for all motors except gripper, because in joint mode the servos - # can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while - # assembling the arm, you could end up with a servo with a position 0 or 4095 at a crucial - # point - for name in self.arm.names: - if name != "gripper": - self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value) + with self.arm.torque_disabled(): + self.arm.configure_motors() + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos + # can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling + # the arm, you could end up with a servo with a position 0 or 4095 at a crucial point + for name in self.arm.names: + if name != "gripper": + self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value) - # Use 'position control current based' for gripper to be limited by the limit of the current. - # For the follower gripper, it means it can grasp an object without forcing too much even tho, - # its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). - # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger - # to make it move, and it will move back to its original target position when we release the force. - self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) + # Use 'position control current based' for gripper to be limited by the limit of the current. For + # the follower gripper, it means it can grasp an object without forcing too much even tho, its + # goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # For the leader gripper, it means we can use it as a physical trigger, since we can force with + # our finger to make it move, and it will move back to its original target position when we + # release the force. + self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) - # Set better PID values to close the gap between recorded states and actions - # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor - self.arm.write("Position_P_Gain", "elbow_flex", 1500) - self.arm.write("Position_I_Gain", "elbow_flex", 0) - self.arm.write("Position_D_Gain", "elbow_flex", 600) - self.arm.enable_torque() + # Set better PID values to close the gap between recorded states and actions + # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor + self.arm.write("Position_P_Gain", "elbow_flex", 1500) + self.arm.write("Position_I_Gain", "elbow_flex", 0) + self.arm.write("Position_D_Gain", "elbow_flex", 600) def get_observation(self) -> dict[str, Any]: if not self.is_connected: diff --git a/lerobot/common/robots/so100/so100_follower.py b/lerobot/common/robots/so100/so100_follower.py index 419b40470..50361fc9e 100644 --- a/lerobot/common/robots/so100/so100_follower.py +++ b/lerobot/common/robots/so100/so100_follower.py @@ -144,17 +144,15 @@ class SO100Follower(Robot): print("Calibration saved to", self.calibration_fpath) def configure(self) -> None: - self.arm.disable_torque() - self.arm.configure_motors() - for name in self.arm.names: - self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value) - # Set P_Coefficient to lower value to avoid shakiness (Default is 32) - self.arm.write("P_Coefficient", name, 16) - # Set I_Coefficient and D_Coefficient to default value 0 and 32 - self.arm.write("I_Coefficient", name, 0) - self.arm.write("D_Coefficient", name, 32) - - self.arm.enable_torque() + with self.arm.torque_disabled(): + self.arm.configure_motors() + for name in self.arm.names: + self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value) + # Set P_Coefficient to lower value to avoid shakiness (Default is 32) + self.arm.write("P_Coefficient", name, 16) + # Set I_Coefficient and D_Coefficient to default value 0 and 32 + self.arm.write("I_Coefficient", name, 0) + self.arm.write("D_Coefficient", name, 32) def get_observation(self) -> dict[str, Any]: if not self.is_connected: diff --git a/lerobot/common/robots/viperx/viperx.py b/lerobot/common/robots/viperx/viperx.py index 744fbc87f..76287b2d2 100644 --- a/lerobot/common/robots/viperx/viperx.py +++ b/lerobot/common/robots/viperx/viperx.py @@ -141,32 +141,31 @@ class ViperX(Robot): logger.info(f"Calibration saved to {self.calibration_fpath}") def configure(self) -> None: - self.arm.disable_torque() - self.arm.configure_motors() + with self.arm.torque_disabled(): + self.arm.configure_motors() - # Set secondary/shadow ID for shoulder and elbow. These joints have two motors. - # As a result, if only one of them is required to move to a certain position, - # the other will follow. This is to avoid breaking the motors. - self.arm.write("Secondary_ID", "shoulder_shadow", 2) - self.arm.write("Secondary_ID", "elbow_shadow", 4) + # Set secondary/shadow ID for shoulder and elbow. These joints have two motors. + # As a result, if only one of them is required to move to a certain position, + # the other will follow. This is to avoid breaking the motors. + self.arm.write("Secondary_ID", "shoulder_shadow", 2) + self.arm.write("Secondary_ID", "elbow_shadow", 4) - # Set a velocity limit of 131 as advised by Trossen Robotics - # TODO(aliberts): remove as it's actually useless in position control - self.arm.write("Velocity_Limit", 131) + # Set a velocity limit of 131 as advised by Trossen Robotics + # TODO(aliberts): remove as it's actually useless in position control + self.arm.write("Velocity_Limit", 131) - # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't - # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, - # you could end up with a servo with a position 0 or 4095 at a crucial point. See: - # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11 - for name in self.arm.names: - if name != "gripper": - self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value) + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos + # can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling + # the arm, you could end up with a servo with a position 0 or 4095 at a crucial point. + # See: https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11 + for name in self.arm.names: + if name != "gripper": + self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value) - # Use 'position control current based' for follower gripper to be limited by the limit of the current. - # It can grasp an object without forcing too much even tho, it's goal position is a complete grasp - # (both gripper fingers are ordered to join and reach a touch). - self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) - self.arm.enable_torque() + # Use 'position control current based' for follower gripper to be limited by the limit of the + # current. It can grasp an object without forcing too much even tho, it's goal position is a + # complete grasp (both gripper fingers are ordered to join and reach a touch). + self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) def get_observation(self) -> dict[str, Any]: """The returned observations do not have a batch dimension."""