Compare commits

...

13 Commits

Author SHA1 Message Date
Francesco Capuano
2b5fce823f fix: camera and motors modules for mock robots 2025-04-23 20:13:38 +02:00
Francesco Capuano
2cce85b5dd fix: action chunks predicted using policy, and timed to observation used 2025-04-19 14:34:36 +02:00
Francesco Capuano
b2d003e6eb fix: client sends timed objects only, and uses lock to read & write robot status 2025-04-19 14:30:29 +02:00
Francesco Capuano
200ba1feb5 add: precommits ignore proto file 2025-04-19 14:18:01 +02:00
Francesco Capuano
0fc9a4341f fix: separate threads for obs streaming, action receiving & execution + action queue reconciliation 2025-04-17 21:09:58 +02:00
Francesco Capuano
d40e74f371 fix: streams inference process using LIFO on obs 2025-04-17 21:09:04 +02:00
Francesco Capuano
40237f5ea3 fix: ruff, get your hands off compiled files 2025-04-17 20:33:54 +02:00
Francesco Capuano
2bcdb57854 fix: bus ids 2025-04-17 20:02:59 +02:00
Francesco Capuano
e9ca1b612d fix: send obs, receives and queues actions chunk, overwrites queue periodically 2025-04-17 19:50:13 +02:00
Francesco Capuano
169babd621 fix: server predicts multiple actions for a given observation, VLA-like 2025-04-17 19:50:02 +02:00
Francesco Capuano
a9031ee1be add: server computes action, robot's daemon constantly reads it 2025-04-17 19:47:20 +02:00
Francesco Capuano
fc107a2c6e add: robot can send observations 2025-04-17 19:47:11 +02:00
Francesco Capuano
84fabbf4af add: grpc service between robot and remote policy server 2025-04-17 19:47:03 +02:00
17 changed files with 1354 additions and 23 deletions

View File

@@ -0,0 +1 @@
# Common mocks for robot devices and testing

View File

View File

@@ -0,0 +1,101 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cache
import numpy as np
CAP_V4L2 = 200
CAP_DSHOW = 700
CAP_AVFOUNDATION = 1200
CAP_ANY = -1
CAP_PROP_FPS = 5
CAP_PROP_FRAME_WIDTH = 3
CAP_PROP_FRAME_HEIGHT = 4
COLOR_RGB2BGR = 4
COLOR_BGR2RGB = 4
ROTATE_90_COUNTERCLOCKWISE = 2
ROTATE_90_CLOCKWISE = 0
ROTATE_180 = 1
@cache
def _generate_image(width: int, height: int):
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
def cvtColor(color_image, color_conversion): # noqa: N802
if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
return color_image[:, :, [2, 1, 0]]
else:
raise NotImplementedError(color_conversion)
def rotate(color_image, rotation):
if rotation is None:
return color_image
elif rotation == ROTATE_90_CLOCKWISE:
return np.rot90(color_image, k=1)
elif rotation == ROTATE_180:
return np.rot90(color_image, k=2)
elif rotation == ROTATE_90_COUNTERCLOCKWISE:
return np.rot90(color_image, k=3)
else:
raise NotImplementedError(rotation)
class VideoCapture:
def __init__(self, *args, **kwargs):
self._mock_dict = {
CAP_PROP_FPS: 30,
CAP_PROP_FRAME_WIDTH: 640,
CAP_PROP_FRAME_HEIGHT: 480,
}
self._is_opened = True
def isOpened(self): # noqa: N802
return self._is_opened
def set(self, propId: int, value: float) -> bool: # noqa: N803
if not self._is_opened:
raise RuntimeError("Camera is not opened")
self._mock_dict[propId] = value
return True
def get(self, propId: int) -> float: # noqa: N803
if not self._is_opened:
raise RuntimeError("Camera is not opened")
value = self._mock_dict[propId]
if value == 0:
if propId == CAP_PROP_FRAME_HEIGHT:
value = 480
elif propId == CAP_PROP_FRAME_WIDTH:
value = 640
return value
def read(self):
if not self._is_opened:
raise RuntimeError("Camera is not opened")
h = self.get(CAP_PROP_FRAME_HEIGHT)
w = self.get(CAP_PROP_FRAME_WIDTH)
ret = True
return ret, _generate_image(width=w, height=h)
def release(self):
self._is_opened = False
def __del__(self):
if self._is_opened:
self.release()

View File

@@ -0,0 +1,148 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import numpy as np
class stream(enum.Enum): # noqa: N801
color = 0
depth = 1
class format(enum.Enum): # noqa: N801
rgb8 = 0
z16 = 1
class config: # noqa: N801
def enable_device(self, device_id: str):
self.device_enabled = device_id
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
self.stream_type = stream_type
# Overwrite default values when possible
self.width = 848 if width is None else width
self.height = 480 if height is None else height
self.color_format = format.rgb8 if color_format is None else color_format
self.fps = 30 if fps is None else fps
class RSColorProfile:
def __init__(self, config):
self.config = config
def fps(self):
return self.config.fps
def width(self):
return self.config.width
def height(self):
return self.config.height
class RSColorStream:
def __init__(self, config):
self.config = config
def as_video_stream_profile(self):
return RSColorProfile(self.config)
class RSProfile:
def __init__(self, config):
self.config = config
def get_stream(self, color_format):
del color_format # unused
return RSColorStream(self.config)
class pipeline: # noqa: N801
def __init__(self):
self.started = False
self.config = None
def start(self, config):
self.started = True
self.config = config
return RSProfile(self.config)
def stop(self):
if not self.started:
raise RuntimeError("You need to start the camera before stop.")
self.started = False
self.config = None
def wait_for_frames(self, timeout_ms=50000):
del timeout_ms # unused
return RSFrames(self.config)
class RSFrames:
def __init__(self, config):
self.config = config
def get_color_frame(self):
return RSColorFrame(self.config)
def get_depth_frame(self):
return RSDepthFrame(self.config)
class RSColorFrame:
def __init__(self, config):
self.config = config
def get_data(self):
data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8)
# Create a difference between rgb and bgr
data[:, :, 0] = 2
return data
class RSDepthFrame:
def __init__(self, config):
self.config = config
def get_data(self):
return np.ones((self.config.height, self.config.width), dtype=np.uint16)
class RSDevice:
def __init__(self):
pass
def get_info(self, camera_info) -> str:
del camera_info # unused
# return fake serial number
return "123456789"
class context: # noqa: N801
def __init__(self):
pass
def query_devices(self):
return [RSDevice()]
class camera_info: # noqa: N801
# fake name
name = "Intel RealSense D435I"
def __init__(self, serial_number):
del serial_number
pass

View File

@@ -0,0 +1 @@
# Mocks for motor modules

View File

@@ -0,0 +1,107 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
Warning: These mocked versions are minimalist. They do not exactly mock every behaviors
from the original classes and functions (e.g. return types might be None instead of boolean).
"""
# from dynamixel_sdk import COMM_SUCCESS
DEFAULT_BAUDRATE = 9_600
COMM_SUCCESS = 0 # tx or rx packet communication success
def convert_to_bytes(value, bytes):
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
# `convert_bytes_to_value`
del bytes # unused
return value
def get_default_motor_values(motor_index):
return {
# Key (int) are from X_SERIES_CONTROL_TABLE
7: motor_index, # ID
8: DEFAULT_BAUDRATE, # Baud_rate
10: 0, # Drive_Mode
64: 0, # Torque_Enable
# Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
# For other joints, 2560 will be autocorrected to be in calibration range
132: 2560, # Present_Position
}
class PortHandler:
def __init__(self, port):
self.port = port
# factory default baudrate
self.baudrate = DEFAULT_BAUDRATE
def openPort(self): # noqa: N802
return True
def closePort(self): # noqa: N802
pass
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
del timeout_ms # unused
def getBaudRate(self): # noqa: N802
return self.baudrate
def setBaudRate(self, baudrate): # noqa: N802
self.baudrate = baudrate
class PacketHandler:
def __init__(self, protocol_version):
del protocol_version # unused
# Use packet_handler.data to communicate across Read and Write
self.data = {}
class GroupSyncRead:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
def addParam(self, motor_index): # noqa: N802
# Initialize motor default values
if motor_index not in self.packet_handler.data:
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS
def getData(self, index, address, bytes): # noqa: N802
return self.packet_handler.data[index][address]
class GroupSyncWrite:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
self.address = address
def addParam(self, index, data): # noqa: N802
# Initialize motor default values
if index not in self.packet_handler.data:
self.packet_handler.data[index] = get_default_motor_values(index)
self.changeParam(index, data)
def txPacket(self): # noqa: N802
return COMM_SUCCESS
def changeParam(self, index, data): # noqa: N802
self.packet_handler.data[index][self.address] = data

View File

@@ -0,0 +1,125 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
Warning: These mocked versions are minimalist. They do not exactly mock every behaviors
from the original classes and functions (e.g. return types might be None instead of boolean).
"""
# from dynamixel_sdk import COMM_SUCCESS
DEFAULT_BAUDRATE = 1_000_000
COMM_SUCCESS = 0 # tx or rx packet communication success
def convert_to_bytes(value, bytes):
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
# `convert_bytes_to_value`
del bytes # unused
return value
def get_default_motor_values(motor_index):
return {
# Key (int) are from SCS_SERIES_CONTROL_TABLE
5: motor_index, # ID
6: DEFAULT_BAUDRATE, # Baud_rate
10: 0, # Drive_Mode
21: 32, # P_Coefficient
22: 32, # D_Coefficient
23: 0, # I_Coefficient
40: 0, # Torque_Enable
41: 254, # Acceleration
31: -2047, # Offset
33: 0, # Mode
55: 1, # Lock
# Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
# For other joints, 2560 will be autocorrected to be in calibration range
56: 2560, # Present_Position
58: 0, # Present_Speed
69: 0, # Present_Current
85: 150, # Maximum_Acceleration
}
class PortHandler:
def __init__(self, port):
self.port = port
# factory default baudrate
self.baudrate = DEFAULT_BAUDRATE
self.ser = SerialMock()
def openPort(self): # noqa: N802
return True
def closePort(self): # noqa: N802
pass
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
del timeout_ms # unused
def getBaudRate(self): # noqa: N802
return self.baudrate
def setBaudRate(self, baudrate): # noqa: N802
self.baudrate = baudrate
class PacketHandler:
def __init__(self, protocol_version):
del protocol_version # unused
# Use packet_handler.data to communicate across Read and Write
self.data = {}
class GroupSyncRead:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
def addParam(self, motor_index): # noqa: N802
# Initialize motor default values
if motor_index not in self.packet_handler.data:
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS
def getData(self, index, address, bytes): # noqa: N802
return self.packet_handler.data[index][address]
class GroupSyncWrite:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
self.address = address
def addParam(self, index, data): # noqa: N802
if index not in self.packet_handler.data:
self.packet_handler.data[index] = get_default_motor_values(index)
self.changeParam(index, data)
def txPacket(self): # noqa: N802
return COMM_SUCCESS
def changeParam(self, index, data): # noqa: N802
self.packet_handler.data[index][self.address] = data
class SerialMock:
def reset_output_buffer(self):
pass
def reset_input_buffer(self):
pass

View File

@@ -48,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
connected to the computer.
"""
if mock:
import tests.cameras.mock_pyrealsense2 as rs
import lerobot.common.mocks.cameras.mock_pyrealsense2 as rs
else:
import pyrealsense2 as rs
@@ -100,7 +100,7 @@ def save_images_from_cameras(
serial_numbers = [cam["serial_number"] for cam in camera_infos]
if mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -253,7 +253,7 @@ class IntelRealSenseCamera:
self.logs = {}
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -287,7 +287,7 @@ class IntelRealSenseCamera:
)
if self.mock:
import tests.cameras.mock_pyrealsense2 as rs
import lerobot.common.mocks.cameras.mock_pyrealsense2 as rs
else:
import pyrealsense2 as rs
@@ -375,7 +375,7 @@ class IntelRealSenseCamera:
)
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2

View File

@@ -80,7 +80,7 @@ def _find_cameras(
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
) -> list[int | str]:
if mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -269,7 +269,7 @@ class OpenCVCamera:
self.logs = {}
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -286,7 +286,7 @@ class OpenCVCamera:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -398,7 +398,7 @@ class OpenCVCamera:
# so we convert the image color from BGR to RGB.
if requested_color_mode == "rgb":
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2

View File

@@ -332,7 +332,7 @@ class DynamixelMotorsBus:
)
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -356,7 +356,7 @@ class DynamixelMotorsBus:
def reconnect(self):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -646,7 +646,7 @@ class DynamixelMotorsBus:
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -691,7 +691,7 @@ class DynamixelMotorsBus:
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -757,7 +757,7 @@ class DynamixelMotorsBus:
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -793,7 +793,7 @@ class DynamixelMotorsBus:
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl

View File

@@ -313,7 +313,7 @@ class FeetechMotorsBus:
)
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -337,7 +337,7 @@ class FeetechMotorsBus:
def reconnect(self):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -664,7 +664,7 @@ class FeetechMotorsBus:
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -702,7 +702,7 @@ class FeetechMotorsBus:
def read(self, data_name, motor_names: str | list[str] | None = None):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -782,7 +782,7 @@ class FeetechMotorsBus:
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -818,7 +818,7 @@ class FeetechMotorsBus:
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs

View File

@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091",
port="/dev/tty.usbmodem58760429101",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
port="/dev/tty.usbmodem58760435821",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],

View File

@@ -0,0 +1,53 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc StreamActions(Empty) returns (stream Action);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1;
bytes data = 2;
}
message Action {
// sent by remote Policy, to Robot
TransferState transfer_state = 1;
bytes data = 2;
}
message Empty {}

View File

@@ -0,0 +1,46 @@
# fmt: off
# flake8: noqa
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xd9\x01\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=216
_globals['_TRANSFERSTATE']._serialized_end=312
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTION']._serialized_start=127
_globals['_ACTION']._serialized_end=205
_globals['_EMPTY']._serialized_start=207
_globals['_EMPTY']._serialized_end=214
_globals['_ASYNCINFERENCE']._serialized_start=315
_globals['_ASYNCINFERENCE']._serialized_end=532
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,193 @@
# fmt: off
# flake8: noqa
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.StreamActions = channel.unary_stream(
'/async_inference.AsyncInference/StreamActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Action.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StreamActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'StreamActions': grpc.unary_stream_rpc_method_handler(
servicer.StreamActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Action.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def StreamActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/async_inference.AsyncInference/StreamActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Action.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -0,0 +1,199 @@
import itertools
import pickle # nosec
import time
from concurrent import futures
from queue import Queue
from typing import Generator, List, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from datasets import load_dataset
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.scripts.server.robot_client import TimedAction, TimedObservation, environment_dt
inference_latency = 1 / 3
idle_wait = 0.1
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def __init__(self):
# TODO: Add device specification for policy inference at init
self.device = "mps"
start = time.time()
self.policy = ACTPolicy.from_pretrained("fracapuano/act_so100_test")
self.policy.to(self.device)
end = time.time()
print(f"Time taken to put policy on {self.device}: {end - start} seconds")
# Initialize dataset action generator
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
self._setup_server()
self.actions_per_chunk = 20
self.actions_overlap = 10
def _setup_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self.observation_queue = Queue(maxsize=1)
def Ready(self, request, context): # noqa: N802
self._setup_server()
print("Client connected and ready")
return async_inference_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
# client_id = context.peer()
# print(f"Receiving observations from {client_id}")
for observation in request_iterator:
timed_observation = pickle.loads(observation.data) # nosec
# If queue is full, get the old observation to make room
if self.observation_queue.full():
# pops from queue
_ = self.observation_queue.get_nowait()
# Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put(timed_observation)
print("Received observation no: ", timed_observation.get_timestep())
return async_inference_pb2.Empty()
def StreamActions(self, request, context): # noqa: N802
"""Stream actions to the robot client"""
# client_id = context.peer()
# print(f"Client {client_id} connected for action streaming")
# Generate action based on the most recent observation and its timestep
obs = self.observation_queue.get()
print("Running inference for timestep: ", obs.get_timestep())
if obs:
yield self._predict_action_chunk(obs)
else:
print("No observation in queue yet!")
time.sleep(idle_wait)
return async_inference_pb2.Empty()
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
]
@torch.no_grad()
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action based on the observation"""
self.policy.eval()
observation = {}
for k, v in observation_t.get_observation().items():
if "image" in k:
observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device)
else:
observation[k] = v.unsqueeze(0).to(self.device)
# Remove batch dimension
action_tensor = self.policy.select_action(observation).squeeze(0)
if action_tensor.dim() == 1:
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
action_tensor = action_tensor.cpu().repeat(self.actions_per_chunk, 1)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
action_bytes = pickle.dumps(action_chunk) # nosec
# Create and return the Action message
action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes)
time.sleep(inference_latency) # slow action generation, emulates inference time (ACT is very fast)
return action
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
"""Stream chunks of actions from a prerecorded dataset.
Returns:
Generator that yields chunks of actions from the dataset
"""
dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")
# 1. Select the action column only, where you will find tensors with 6 elements
actions = dataset["action"]
action_indices = torch.arange(len(actions))
# 2. Chunk the iterable of tensors into chunks with 10 elements each
# sending only first element for debugging
indices_chunks = action_indices.unfold(
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
)
for idx_chunk in indices_chunks:
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
def _read_action_chunk(self, observation: Optional[TimedObservation] = None):
"""Dummy function for predicting action chunk given observation.
Instead of computing actions on-the-fly, this method streams
actions from a prerecorded dataset.
"""
import warnings
warnings.warn(
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
)
if not observation:
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
transfer_state = 0
else:
transfer_state = observation.transfer_state
# Get chunk of actions from the generator
actions_chunk = next(self.action_generator)
# Return a list of TimedActions, with timestamps starting from the observation timestamp
action_data = self._time_action_chunk(
observation.get_timestamp(), actions_chunk, observation.get_timestep()
)
action_bytes = pickle.dumps(action_data) # nosec
# Create and return the Action message
action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes)
time.sleep(inference_latency) # slow action generation, emulates inference time
return action
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(PolicyServer(), server)
server.add_insecure_port("[::]:50051")
server.start()
print("PolicyServer started on port 50051")
try:
while True:
time.sleep(86400) # Sleep for a day, or until interrupted
except KeyboardInterrupt:
server.stop(0)
print("Server stopped")
if __name__ == "__main__":
serve()

View File

@@ -0,0 +1,357 @@
import pickle # nosec
import threading
import time
from queue import Empty, Queue
from typing import Any, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from lerobot.common.robot_devices.robots.utils import make_robot
environment_dt = 1 / 30
idle_wait = 0.1
class TimedData:
def __init__(self, timestamp: float, data: Any, timestep: int):
"""Initialize a TimedData object.
Args:
timestamp: Unix timestamp relative to data's creation.
data: The actual data to wrap a timestamp around.
"""
self.timestamp = timestamp
self.data = data
self.timestep = timestep
def get_data(self):
return self.data
def get_timestamp(self):
return self.timestamp
def get_timestep(self):
return self.timestep
class TimedAction(TimedData):
def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
super().__init__(timestamp=timestamp, data=action, timestep=timestep)
def get_action(self):
return self.get_data()
class TimedObservation(TimedData):
def __init__(
self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0
):
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
self.transfer_state = transfer_state
def get_observation(self):
return self.get_data()
class RobotClient:
def __init__(
self,
# cfg: RobotConfig,
server_address="localhost:50051",
use_robot=True,
):
self.channel = grpc.insecure_channel(server_address)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.running = False
self.first_observation_sent = False
self.latest_action = 0
self.action_chunk_size = 20
self.action_queue = Queue()
self.start_barrier = threading.Barrier(3)
# Create a lock for robot access
self.robot_lock = threading.Lock()
self.use_robot = use_robot
if self.use_robot:
self.robot = make_robot("so100")
self.robot.connect()
time.sleep(idle_wait) # sleep waiting for cameras to activate
print("Robot connected")
self.robot_reading = True
def timestamps(self):
"""Get the timestamps of the actions in the queue"""
return sorted([action.get_timestep() for action in self.action_queue.queue])
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
self.stub.Ready(async_inference_pb2.Empty())
print("Connected to policy server")
self.running = True
return True
except grpc.RpcError as e:
print(f"Failed to connect to policy server: {e}")
return False
def stop(self):
"""Stop the robot client"""
self.running = False
if self.use_robot and hasattr(self, "robot"):
self.robot.disconnect()
self.channel.close()
def send_observation(
self,
obs: TimedObservation,
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
) -> bool:
"""Send observation to the policy server.
Returns True if the observation was sent successfully, False otherwise."""
if not self.running:
print("Client not running")
return False
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
observation_bytes = pickle.dumps(obs)
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
try:
_ = self.stub.SendObservations(iter([observation]))
if transfer_state == async_inference_pb2.TRANSFER_BEGIN:
self.first_observation_sent = True
return True
except grpc.RpcError as e:
print(f"Error sending observation: {e}")
return False
def _validate_action(self, action: TimedAction):
"""Received actions are keps only when they have been produced for now or later, never before"""
return not action.get_timestamp() < self.latest_action
def _validate_action_chunk(self, actions: list[TimedAction]):
assert len(actions) == self.action_chunk_size, (
f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
)
assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"
return True
def _inspect_action_queue(self):
print("Queue size: ", self.action_queue.qsize())
print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
def _clear_queue(self):
"""Clear the existing queue"""
while not self.action_queue.empty():
try:
self.action_queue.get_nowait()
except Empty:
break
def _fill_action_queue(self, actions: list[TimedAction]):
"""Fill the action queue with incoming valid actions"""
for action in actions:
if self._validate_action(action):
self.action_queue.put(action)
def _update_action_queue(self, actions: list[TimedAction]):
"""Aggregate incoming actions into the action queue.
Raises NotImplementedError as this is not implemented yet.
Args:
actions: List of TimedAction instances to queue
"""
# TODO: Implement this
raise NotImplementedError("Not implemented")
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
"""Clear the existing queue and fill it with new actions.
This is a higher-level function that combines clearing and filling operations.
Args:
actions: List of TimedAction instances to queue
"""
print("*** Current latest action: ", self.latest_action, "***")
print("\t**** Current queue content ****: ")
self._inspect_action_queue()
print("\t*** Incoming actions ****: ")
print([a.get_timestep() for a in actions])
self._clear_queue()
self._fill_action_queue(actions)
print("\t*** Queue after clearing and filling ****: ")
self._inspect_action_queue()
def receive_actions(self):
"""Receive actions from the policy server"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
print("Action receiving thread starting")
while self.running:
try:
# Use StreamActions to get a stream of actions from the server
for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
# Deserialize bytes back into list[TimedAction]
timed_actions = pickle.loads(actions_chunk.data) # nosec
# strategy for queue composition is specified in the method
self._clear_and_fill_action_queue(timed_actions)
except grpc.RpcError as e:
print(f"Error receiving actions: {e}")
time.sleep(idle_wait) # Avoid tight loop on error
def _get_next_action(self) -> Optional[TimedAction]:
"""Get the next action from the queue"""
try:
action = self.action_queue.get_nowait()
return action
except Empty:
return None
def execute_actions(self):
"""Continuously execute actions from the queue"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
print("Action execution thread starting")
while self.running:
# Get the next action from the queue
time.sleep(environment_dt)
timed_action = self._get_next_action()
if timed_action is not None:
# self.latest_action = timed_action.get_timestep()
self.latest_action = timed_action.get_timestamp()
# Convert action to tensor and send to robot
if self.use_robot:
# Acquire lock before accessing the robot
if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
try:
self.robot.send_action(timed_action.get_action())
finally:
# Always release the lock in a finally block to ensure it's released
self.robot_lock.release()
else:
print("Could not acquire robot lock for action execution, retrying next cycle")
else:
# No action available, wait and retry fetching from queue
time.sleep(idle_wait)
def stream_observations(self, get_observation_fn):
"""Continuously stream observations to the server"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
print("Observation streaming thread starting")
first_observation = True
while self.running:
try:
# Get serialized observation bytes from the function
time.sleep(environment_dt)
observation = get_observation_fn()
# Skip if observation is None (couldn't acquire lock)
if observation is None:
continue
# Set appropriate transfer state
if first_observation:
state = async_inference_pb2.TRANSFER_BEGIN
first_observation = False
else:
state = async_inference_pb2.TRANSFER_MIDDLE
self.send_observation(observation, state)
except Exception as e:
print(f"Error in observation sender: {e}")
time.sleep(idle_wait)
def async_client():
# Example of how to use the RobotClient
client = RobotClient()
if client.start():
# Function to generate mock observations
def get_observation():
# Create a counter attribute if it doesn't exist
if not hasattr(get_observation, "counter"):
get_observation.counter = 0
# Acquire lock before accessing the robot
observation_content = None
if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
try:
observation_content = client.robot.capture_observation()
finally:
# Always release the lock in a finally block to ensure it's released
client.robot_lock.release()
else:
print("Could not acquire robot lock for observation capture, skipping this cycle")
return None # Return None to indicate no observation was captured
observation = TimedObservation(
timestamp=time.time(), observation=observation_content, timestep=get_observation.counter
)
# Increment counter for next call
get_observation.counter += 1
return observation
print("Starting all threads...")
# Create and start observation sender thread
obs_thread = threading.Thread(target=client.stream_observations, args=(get_observation,))
obs_thread.daemon = True
# Create and start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions)
action_receiver_thread.daemon = True
# Create action execution thread
action_execution_thread = threading.Thread(target=client.execute_actions)
action_execution_thread.daemon = True
# Start all threads
obs_thread.start()
action_receiver_thread.start()
action_execution_thread.start()
try:
# Main thread just keeps everything alive
while client.running:
time.sleep(idle_wait)
except KeyboardInterrupt:
pass
finally:
client.stop()
print("Client stopped")
if __name__ == "__main__":
async_client()