Compare commits

...

7 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| fracapuano | 241e7076f2 | add: async inference stack | 2025-06-03 18:03:42 +02:00 |
| Adil Zouitine | 0cf864870c | [Fix] Unpin torch beyond 2.6.0 & torchcodec beyond 0.2.1 (#1127) | 2025-05-28 16:54:20 +02:00 |
| mshukor | 1786916a16 | Update README.md (#1163) | 2025-05-27 11:50:43 +02:00 |
| mshukor | 0507ad4f68 | Update README.md (#1160) | 2025-05-27 11:45:07 +02:00 |
| Ragnar | bed90e3a41 | fix: typos and grammar (#1148) | 2025-05-25 17:20:45 +02:00 |
| Francesco Capuano | 6163daaaa4 | Fix: emptying action queue between resets (#1117) | 2025-05-22 21:37:21 +02:00 |
| Pepijn | 8e2a394442 | Add editable -e for feetech install command (#1133) | 2025-05-20 18:51:21 +02:00 |
32 changed files with 1566 additions and 25 deletions

View File

@@ -360,7 +360,7 @@ with profile(
If you want, you can cite this work with:
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}

View File

@@ -55,7 +55,7 @@ conda install ffmpeg -c conda-forge
Install 🤗 LeRobot:
```bash
cd lerobot && pip install ".[feetech]"
cd lerobot && pip install -e ".[feetech]"
```
## Troubleshooting

View File

@@ -141,7 +141,7 @@ python lerobot/scripts/configure_motor.py \
--ID 1
```
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Then unplug your motor and plug the second motor and set its ID to 2.
```bash

View File

@@ -61,7 +61,7 @@ conda install ffmpeg -c conda-forge
Install 🤗 LeRobot:
```bash
cd lerobot && pip install ".[feetech]"
cd lerobot && pip install -e ".[feetech]"
```
> [!NOTE]

View File

@@ -106,7 +106,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
class AsyncImageWriter:
"""
This class abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
save images on disk asynchronously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it creates a threads pool of size `num_threads`.
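As a hedged aside, the thread-pool pattern the docstring describes could be sketched roughly as follows; this is not the actual `AsyncImageWriter` implementation, and it assumes numpy images and Pillow for saving.
```python
# Minimal sketch of the pattern described above, not the actual implementation:
# worker threads drain a queue of (image, path) jobs so the control loop never
# blocks on disk I/O.
import queue
import threading
import numpy as np
from PIL import Image

def _worker(jobs: queue.Queue) -> None:
    while True:
        item = jobs.get()
        if item is None:  # sentinel: shut the worker down
            break
        image, path = item
        Image.fromarray(image).save(path)
        jobs.task_done()

jobs: queue.Queue = queue.Queue()
threads = [threading.Thread(target=_worker, args=(jobs,), daemon=True) for _ in range(4)]
for t in threads:
    t.start()

jobs.put((np.zeros((480, 640, 3), dtype=np.uint8), "frame_000.png"))
jobs.join()  # wait until the frame is written
```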

View File

@@ -944,7 +944,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def stop_image_writer(self) -> None:
"""
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.stop()
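A rough usage sketch of the behaviour this docstring describes, assuming `dataset` is an already-built `LeRobotDataset` that still holds an image writer:
```python
# Hedged sketch: drop the non-picklable image writer before handing the dataset
# to a multi-worker DataLoader. `dataset` is assumed to be a LeRobotDataset.
from torch.utils.data import DataLoader

dataset.stop_image_writer()
loader = DataLoader(dataset, batch_size=8, num_workers=4, shuffle=True)
for batch in loader:
    break  # workers can now pickle the dataset without the writer handle
```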

View File

@@ -101,7 +101,7 @@ def decode_video_frames_torchvision(
keyframes_only = False
torchvision.set_video_backend(backend)
if backend == "pyav":
keyframes_only = True # pyav doesnt support accuracte seek
keyframes_only = True # pyav doesn't support accurate seek
# set a video stream reader
# TODO(rcadene): also load audio stream at the same time
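A hedged illustration of what the flag guards against (the file path is hypothetical): with the `pyav` backend, `seek` only lands on keyframes, so frames are decoded forward from there until the requested timestamps are reached.
```python
# Hedged sketch of keyframe-only seeking with torchvision's pyav backend.
import itertools
import torchvision

torchvision.set_video_backend("pyav")
reader = torchvision.io.VideoReader("episode_0.mp4", "video")  # hypothetical file
reader.seek(2.0, keyframes_only=True)  # lands on the nearest keyframe
frames = [f["data"] for f in itertools.takewhile(lambda f: f["pts"] <= 2.5, reader)]
```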

View File

@@ -357,7 +357,7 @@ class PI0Policy(PreTrainedPolicy):
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]

View File

@@ -516,7 +516,7 @@ class PI0FAST(nn.Module):
interpolate_like_pi=self.config.interpolate_like_pi,
)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]

View File

@@ -243,6 +243,11 @@ def control_loop(
timestamp = 0
start_episode_t = time.perf_counter()
# Control starts; if a policy is given, it needs cleaning up
if policy is not None:
policy.reset()
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
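For context on why the added `policy.reset()` matters (per the "emptying action queue between resets" commit above), here is a toy illustration, not LeRobot's actual policy class: action-chunking policies keep a queue of not-yet-executed actions that must be cleared between episodes.
```python
# Toy illustration of the state that the added reset() call clears between episodes.
from collections import deque

class ToyChunkingPolicy:
    """Hypothetical stand-in for an action-chunking policy such as ACT."""

    def __init__(self) -> None:
        self._action_queue: deque = deque()

    def reset(self) -> None:
        # Without this, leftover actions from the previous episode would be
        # replayed at the start of the next one.
        self._action_queue.clear()
```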

View File

@@ -0,0 +1,60 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// Robot sends observations to & executes actions received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc StreamActions(Empty) returns (stream Action);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1;
bytes data = 2;
}
message Action {
// sent by remote Policy, to Robot
TransferState transfer_state = 1;
bytes data = 2;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
TransferState transfer_state = 1;
bytes data = 2;
}
message Empty {}
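A hedged sketch of how a client might exercise these four RPCs, using the generated `async_inference_pb2` / `async_inference_pb2_grpc` modules added later in this PR; the payloads below are placeholders, whereas the real client sends pickled `TinyPolicyConfig` and `TimedObservation` objects.
```python
# Hedged client-side sketch of the AsyncInference service defined above.
import grpc
import async_inference_pb2
import async_inference_pb2_grpc

channel = grpc.insecure_channel("localhost:8080")
stub = async_inference_pb2_grpc.AsyncInferenceStub(channel)

stub.Ready(async_inference_pb2.Empty())  # handshake

setup = async_inference_pb2.PolicySetup(
    transfer_state=async_inference_pb2.TRANSFER_BEGIN,
    data=b"...",  # placeholder; the real client sends a pickled TinyPolicyConfig
)
stub.SendPolicyInstructions(setup)

obs = async_inference_pb2.Observation(
    transfer_state=async_inference_pb2.TRANSFER_MIDDLE,
    data=b"...",  # placeholder; the real client sends a pickled TimedObservation
)
stub.SendObservations(iter([obs]))  # client-streaming RPC: pass an iterator

for action in stub.StreamActions(async_inference_pb2.Empty()):  # server-streaming RPC
    print(f"received {len(action.data)} bytes of serialized actions")
    break
```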

View File

@@ -0,0 +1,48 @@
# fmt: off
# flake8: noqa
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x0bPolicySetup\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xa9\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=301
_globals['_TRANSFERSTATE']._serialized_end=397
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTION']._serialized_start=127
_globals['_ACTION']._serialized_end=205
_globals['_POLICYSETUP']._serialized_start=207
_globals['_POLICYSETUP']._serialized_end=290
_globals['_EMPTY']._serialized_start=292
_globals['_EMPTY']._serialized_end=299
_globals['_ASYNCINFERENCE']._serialized_start=400
_globals['_ASYNCINFERENCE']._serialized_end=697
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,236 @@
# fmt: off
# flake8: noqa
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot sends observations to & executes actions received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.StreamActions = channel.unary_stream(
'/async_inference.AsyncInference/StreamActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Action.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/async_inference.AsyncInference/SendPolicyInstructions',
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot sends observations to & executes actions received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StreamActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'StreamActions': grpc.unary_stream_rpc_method_handler(
servicer.StreamActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Action.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=async__inference__pb2.PolicySetup.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot sends observations to & executes actions received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def StreamActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/async_inference.AsyncInference/StreamActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Action.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/SendPolicyInstructions',
async__inference__pb2.PolicySetup.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -0,0 +1,12 @@
"""Server/Client side: Sometimes you just want the environment to wait a tiny bit"""
idle_wait = 0.01
"""Client side: The environment evolves with a time resolution equal to environment_dt"""
environment_dt = 1 / 30
"""Server side: Running inference on (at most) environment_dt"""
inference_latency = environment_dt
"""Supported policies"""
supported_policies = ["act", "smolvla"]
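A hedged sketch of how these constants pace the loops elsewhere in this PR; the module path matches the imports used in the server and client files below.
```python
# Hedged sketch: keeping a loop near 1 / environment_dt (~30 Hz) using the constants above.
import time
from lerobot.scripts.server.constants import environment_dt

for _ in range(100):
    t0 = time.perf_counter()
    # ... do one control step (pop an action from the queue, maybe send an observation) ...
    elapsed = time.perf_counter() - t0
    time.sleep(max(0.0, environment_dt - elapsed))  # same pattern as control_loop() below
```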

View File

@@ -0,0 +1,128 @@
import logging
import logging.handlers
import os
import time
from typing import Any
import torch
def setup_logging(prefix: str, info_bracket: str):
"""Sets up logging"""
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)
# Delete any existing prefix_* log files
for old_log_file in os.listdir("logs"):
if old_log_file.startswith(prefix) and old_log_file.endswith(".log"):
try:
os.remove(os.path.join("logs", old_log_file))
print(f"Deleted old log file: {old_log_file}")
except Exception as e:
print(f"Failed to delete old log file {old_log_file}: {e}")
# Set up logging with both console and file output
logger = logging.getLogger(prefix)
# Prevent propagation to root logger to avoid duplicate messages
logger.propagate = False
logger.setLevel(logging.INFO)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(
logging.Formatter(
f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
logger.addHandler(console_handler)
# File handler - creates a new log file for each run
file_handler = logging.handlers.RotatingFileHandler(
f"logs/policy_server_{int(time.time())}.log",
maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5,
)
file_handler.setFormatter(
logging.Formatter(
f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
logger.addHandler(file_handler)
return logger
class TimedData:
def __init__(self, timestamp: float, data: Any, timestep: int):
"""Initialize a TimedData object.
Args:
timestamp: Unix timestamp relative to data's creation.
data: The actual data to wrap a timestamp around.
timestep: The timestep of the data.
"""
self.timestamp = timestamp
self.data = data
self.timestep = timestep
def get_data(self):
return self.data
def get_timestamp(self):
return self.timestamp
def get_timestep(self):
return self.timestep
class TimedAction(TimedData):
def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
super().__init__(timestamp=timestamp, data=action, timestep=timestep)
def get_action(self):
return self.get_data()
class TimedObservation(TimedData):
def __init__(
self,
timestamp: float,
observation: dict[str, torch.Tensor],
timestep: int,
transfer_state: int = 0,
must_go: bool = False,
):
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
self.transfer_state = transfer_state
self.must_go = must_go
def get_observation(self):
return self.get_data()
class TinyPolicyConfig:
def __init__(
self,
policy_type: str = "act",
pretrained_name_or_path: str = "fracapuano/act_so100_test",
device: str = "cpu",
):
self.policy_type = policy_type
self.pretrained_name_or_path = pretrained_name_or_path
self.device = device
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
"""Check if two observation states are similar, under a tolerance threshold"""
return torch.linalg.norm(obs1_state - obs2_state) < atol
def observations_similar(obs1: TimedObservation, obs2: TimedObservation, atol: float = 1) -> bool:
"""Check if two observations are similar, under a tolerance threshold"""
obs1_state = obs1.get_observation()["observation.state"]
obs2_state = obs2.get_observation()["observation.state"]
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
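A brief hedged usage sketch for these helpers, importing them from the module path this PR uses elsewhere:
```python
# Hedged usage sketch for TimedObservation / TimedAction / observations_similar.
import time
import torch
from lerobot.scripts.server.helpers import TimedAction, TimedObservation, observations_similar

obs_a = TimedObservation(
    timestamp=time.time(), observation={"observation.state": torch.zeros(6)}, timestep=0
)
obs_b = TimedObservation(
    timestamp=time.time(), observation={"observation.state": 0.1 * torch.ones(6)}, timestep=1
)
print(observations_similar(obs_a, obs_b, atol=1))  # True: L2 distance ~0.24 < 1

act = TimedAction(timestamp=time.time(), action=torch.zeros(6), timestep=0)
print(act.get_timestep(), act.get_action().shape)  # 0 torch.Size([6])
```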

View File

@@ -0,0 +1,429 @@
import itertools
import pickle # nosec
import time
from concurrent import futures
from queue import Queue
from typing import Generator, List, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from datasets import load_dataset
from lerobot.common.policies.factory import get_policy_class
from lerobot.scripts.server.constants import environment_dt, idle_wait, inference_latency, supported_policies
from lerobot.scripts.server.helpers import (
TimedAction,
TimedObservation,
TinyPolicyConfig,
observations_similar,
setup_logging,
)
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
prefix = "policy_server"
info_bracket = "SERVER"
logger = setup_logging(prefix, info_bracket)
def __init__(self):
# Initialize dataset action generator (to debug this first version, will be removed in the future)
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
self._setup_server()
self.actions_per_chunk = 20
self.actions_overlap = 10
self.running = True
def _setup_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self.observation_queue = Queue(maxsize=1)
self._predicted_timesteps = set()
self._predicted_observations = Queue(maxsize=1)
def Ready(self, request, context): # noqa: N802
client_id = context.peer()
self.logger.info(f"Client {client_id} connected and ready")
self._setup_server()
return async_inference_pb2.Empty()
def SendPolicyInstructions(self, request, context): # noqa: N802
"""Receive policy instructions from the robot client"""
client_id = context.peer()
self.logger.debug(f"Receiving policy instructions from {client_id}")
policy_specs = pickle.loads(request.data) # nosec
assert isinstance(policy_specs, TinyPolicyConfig), (
f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}"
)
self.logger.info(
f"Policy type: {policy_specs.policy_type} | "
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
f"Device: {policy_specs.device}"
)
assert policy_specs.policy_type in supported_policies, (
f"Policy type {policy_specs.policy_type} not supported. Supported policies: {supported_policies}"
)
self.device = policy_specs.device
self.policy_type = policy_specs.policy_type # act, pi0, etc.
policy_class = get_policy_class(self.policy_type)
start = time.time()
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
self.policy.to(self.device)
end = time.time()
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
return async_inference_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
client_id = context.peer()
self.logger.debug(f"Receiving observations from {client_id}")
for observation in request_iterator:
receive_time = time.time()
timed_observation = pickle.loads(observation.data) # nosec
deserialize_time = time.time()
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
if not self._maybe_enqueue_observation(timed_observation):
continue
queue_time = time.time()
obs_timestep = timed_observation.get_timestep()
obs_timestamp = timed_observation.get_timestamp()
self.logger.info(
f"Received observation #{obs_timestep} | "
f"Client timestamp: {obs_timestamp:.6f} | "
f"Server timestamp: {receive_time:.6f} | "
)
if not hasattr(self, "previous_obs_timestamp"):
self.previous_obs_timestamp = obs_timestamp
self.logger.debug(
f"1/DeltaObsT (~frequency): {1 / (1e-6 + obs_timestamp - self.previous_obs_timestamp):.6f} Hz| "
f"Network latency: {receive_time - obs_timestamp:.6f}s | "
f"Deserialization time: {deserialize_time - receive_time:.6f}s | "
f"Queue time: {queue_time - deserialize_time:.6f}s | "
)
self.previous_obs_timestamp = obs_timestamp
return async_inference_pb2.Empty()
def StreamActions(self, request, context): # noqa: N802
"""Stream actions to the robot client"""
client_id = context.peer()
self.logger.debug(f"Client {client_id} connected for action streaming")
# Generate action based on the most recent observation and its timestep
try:
obs = self.observation_queue.get()
self.logger.info(
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
)
if obs:
self.last_predicted_obs = obs
self._predicted_timesteps.add(obs.get_timestep())
start_time = time.time()
action_chunk = self._predict_action_chunk(obs)
# action_chunk = self._read_action_chunk(obs)
inference_time = time.time() - start_time
start_time = time.time()
action_bytes = pickle.dumps(action_chunk) # nosec
serialize_time = time.time() - start_time
# Create and return the Action
action = async_inference_pb2.Action(transfer_state=obs.transfer_state, data=action_bytes)
self.logger.info(
f"Action chunk #{obs.get_timestep()} generated | Inference time: {inference_time:.6f}s |"
)
self.logger.debug(
f"Action chunk #{obs.get_timestep()} generated | "
f"Inference time: {inference_time:.6f}s |"
f"Serialize time: {serialize_time:.6f}s |"
f"Total time: {inference_time + serialize_time:.6f}s"
)
yield action
else:
self.logger.warning("No observation in queue yet!")
time.sleep(idle_wait)
except Exception as e:
self.logger.error(f"Error in StreamActions: {e}")
return async_inference_pb2.Empty()
def _enqueue_and_go(self, obs: TimedObservation):
# If queue is full, get the old observation to make room
if self.observation_queue.full():
# pops from queue
_ = self.observation_queue.get_nowait()
self.logger.debug("Observation queue was full, removed oldest observation")
# Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put(obs)
return True
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
if obs.get_timestep() in self._predicted_timesteps:
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
return False
elif observations_similar(obs, previous_obs, atol=1):
self.logger.debug(
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
)
return False
else:
return True
def _maybe_enqueue_observation(self, obs: TimedObservation) -> bool:
"""Enqueue an observation if it must go through processing, otherwise skip it.
Observations not in queue are never run through the policy network"""
if obs.must_go or not hasattr(self, "last_predicted_obs"):
self.logger.info(f"[MUST GO] Enqueued observation #{obs.get_timestep()} for direct processing!")
return self._enqueue_and_go(obs)
else:
if self._obs_sanity_checks(obs, self.last_predicted_obs):
return self._enqueue_and_go(obs)
else:
return False
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
]
@torch.no_grad()
def _run_act_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Run ACT-like policies"""
start_time = time.time()
# prepare observation for policy forward pass
batch = self.policy.normalize_inputs(observation)
normalize_time = time.time()
self.logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s")
if self.policy.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.policy.config.image_features]
prep_time = time.time()
self.logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s")
# forward pass outputs up to policy.config.n_action_steps != actions_per_chunk
actions = self.policy.model(batch)[0][:, : self.actions_per_chunk]
actions = self.policy.unnormalize_outputs({"action": actions})["action"]
end_time = time.time()
self.logger.info(f"[ACT] Action chunk generation total time: {end_time - start_time:.6f}s")
return actions
@torch.no_grad()
def _run_pi0_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Run PI0-like policies"""
raise NotImplementedError("PI0 policy not implemented yet")
@torch.no_grad()
def _run_smolvla_policy(
self, observation: dict[str, torch.Tensor], noise: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Run smolvla-like policies"""
observation = self.policy.normalize_inputs(observation)
images, img_masks = self.policy.prepare_images(observation)
state = self.policy.prepare_state(observation)
lang_tokens, lang_masks = self.policy.prepare_language(observation)
actions = self.policy.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.policy.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.policy.unnormalize_outputs(
{"action": actions, "robot_type": [self.policy.config.robot_type]}
)["action"]
return actions
def _get_action_chunk(
self, observation: dict[str, torch.Tensor], policy_type: str = "act"
) -> torch.Tensor:
"""Get an action chunk from the policy"""
if policy_type == "act":
return self._run_act_policy(observation)
elif policy_type == "smolvla":
return self._run_smolvla_policy(observation)
else:
raise ValueError(f"Policy class {policy_type} not supported")
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action based on the observation"""
"""1. Prepare observation"""
start_time = time.time()
observation = {
"robot_type": [self.policy.config.robot_type],
}
for k, v in observation_t.get_observation().items():
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions
if "image" in k:
# Add batch dimension first, then reorder to NCHW format, then normalize to [0, 1]
observation[k] = (
v.unsqueeze(0).permute(0, 3, 1, 2).to(self.device, non_blocking=True) / 255.0
)
else:
observation[k] = v.unsqueeze(0).to(self.device, non_blocking=True)
else:
observation[k] = v # textual instructions are passed as a list of strings
prep_time = time.time()
self.logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s")
"""2. Get action chunk"""
action_tensor = self._get_action_chunk(observation, self.policy_type)
action_tensor = action_tensor.squeeze(0)
# Move to CPU before serializing
action_tensor = action_tensor.cpu()
post_inference_time = time.time()
self.logger.debug(f"Post-inference processing start: {post_inference_time - prep_time:.6f}s")
if action_tensor.dim() == 1:
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
action_tensor = action_tensor.repeat(self.actions_per_chunk, 1)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
chunk_time = time.time()
self.logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s")
time.sleep(
max(0, inference_latency - max(0, chunk_time - start_time))
) # sleep to control inference latency
return action_chunk
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
"""Stream chunks of actions from a prerecorded dataset.
Returns:
Generator that yields chunks of actions from the dataset
"""
import warnings
warnings.warn(
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
)
dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")
# 1. Select the action column only, where you will find tensors with 6 elements
actions = dataset["action"]
action_indices = torch.arange(len(actions))
# 2. Chunk the iterable of tensors into chunks with 10 elements each
# sending only first element for debugging
indices_chunks = action_indices.unfold(
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
)
for idx_chunk in indices_chunks:
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
def _read_action_chunk(self, observation: Optional[TimedObservation] = None) -> list[TimedAction]:
"""Dummy function for predicting action chunk given observation.
Instead of computing actions on-the-fly, this method streams
actions from a prerecorded dataset.
"""
import warnings
warnings.warn(
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
)
start_time = time.time()
if not observation:
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
# Get chunk of actions from the generator
actions_chunk = next(self.action_generator)
# Return a list of TimedActions, with timestamps starting from the observation timestamp
actions_chunk = self._time_action_chunk(
observation.get_timestamp(), actions_chunk, observation.get_timestep()
)
chunk_time = time.time()
self.logger.debug(f"Action chunk creation time: {chunk_time - start_time:.6f}s")
# slow action generation, emulates inference time
time.sleep(max(0, inference_latency - max(0, chunk_time - start_time)))
return actions_chunk
def stop(self):
"""Stop the server"""
self.running = False
self.logger.info("Server stopping...")
def serve():
port = 8080
# Create the server instance first
policy_server = PolicyServer()
# Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
server.add_insecure_port(f"[::]:{port}")
server.start()
policy_server.logger.info(f"PolicyServer started on port {port}")
try:
# Use the running attribute to control server lifetime
while policy_server.running:
time.sleep(1) # Check every second instead of sleeping indefinitely
except KeyboardInterrupt:
policy_server.stop()
policy_server.logger.info("Keyboard interrupt received")
if __name__ == "__main__":
serve()
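One detail worth illustrating (hedged, with the values `actions_per_chunk=20` and `actions_overlap=10` set above): the `unfold(0, 20, 10)` call in `_stream_action_chunks_from_dataset` produces index windows that overlap by 10 steps, so consecutive chunks share half of their actions.
```python
# Hedged illustration of the chunk overlap produced by unfold(0, size=20, step=10).
import torch

idx = torch.arange(50).unfold(0, 20, 10)
print(idx.shape)              # torch.Size([4, 20]) -> 4 overlapping windows
print(idx[0][0], idx[0][-1])  # tensor(0) tensor(19)
print(idx[1][0], idx[1][-1])  # tensor(10) tensor(29): shares steps 10..19 with window 0
```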

View File

@@ -0,0 +1,608 @@
import argparse
import os
import pickle # nosec
import threading
import time
from queue import Empty, Queue
from typing import Callable, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.scripts.server.constants import environment_dt, idle_wait
from lerobot.scripts.server.helpers import TimedAction, TimedObservation, TinyPolicyConfig, setup_logging
class RobotClient:
prefix = "robot_client"
info_bracket = "CLIENT"
logger = setup_logging(prefix, info_bracket)
def __init__(
self,
server_address: Optional[str] = None,
policy_type: str = "smolvla",
pretrained_name_or_path: str = "lerobot/smolvla_base",
policy_device: str = "cuda",
chunk_size_threshold: float = 0.5,
robot: str = "so100",
):
# Use environment variable if server_address is not provided
if server_address is None:
server_address = os.getenv("SERVER_ADDRESS", "localhost:8080")
self.logger.info(f"No server address provided, using default address: {server_address}")
self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device)
self.channel = grpc.insecure_channel(server_address)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.logger.info(f"Initializing client to connect to server at {server_address}")
self.running = False
self.must_go = True # does the observation qualify for direct processing on the policy server?
self.latest_action = -1
self.action_chunk_size = -1
self._chunk_size_threshold = chunk_size_threshold
self.action_queue = Queue()
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
start_time = time.time()
self.robot = make_robot(robot)
self.robot.connect()
connect_time = time.time()
self.logger.info(f"Robot connection time: {connect_time - start_time:.4f}s")
time.sleep(idle_wait) # sleep waiting for cameras to activate
self.logger.info("Robot connected and ready")
def timestamps(self):
"""Get the timestamps of the actions in the queue"""
return sorted([action.get_timestep() for action in self.action_queue.queue])
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
start_time = time.time()
self.stub.Ready(async_inference_pb2.Empty())
end_time = time.time()
self.logger.info(f"Connected to policy server in {end_time - start_time:.4f}s")
# send policy instructions
policy_config_bytes = pickle.dumps(self.policy_config)
policy_setup = async_inference_pb2.PolicySetup(
transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes
)
self.logger.info("Sending policy instructions to policy server")
self.logger.info(
f"Policy type: {self.policy_config.policy_type} | "
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
f"Device: {self.policy_config.device}"
)
self.stub.SendPolicyInstructions(policy_setup)
self.running = True
self.available_actions_size = []
return True
except grpc.RpcError as e:
self.logger.error(f"Failed to connect to policy server: {e}")
return False
def stop(self):
"""Stop the robot client"""
self.running = False
self.robot.disconnect()
self.logger.info("Robot disconnected")
self.channel.close()
self.logger.info("Client stopped, channel closed")
def send_observation(
self,
obs: TimedObservation,
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
) -> bool:
"""Send observation to the policy server.
Returns True if the observation was sent successfully, False otherwise."""
if not self.running:
self.logger.warning("Client not running")
return False
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
start_time = time.time()
observation_bytes = pickle.dumps(obs)
serialize_time = time.time()
self.logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s")
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
try:
send_start = time.time()
_ = self.stub.SendObservations(iter([observation]))
send_end = time.time()
obs_timestep = obs.get_timestep()
self.logger.info(
f"Sent observation #{obs_timestep} | "
f"Serialize time: {serialize_time - start_time:.6f}s | "
f"Network time: {send_end - send_start:.6f}s | "
f"Total time: {send_end - start_time:.6f}s"
)
self.last_obs_sent_time = send_end
return True
except grpc.RpcError as e:
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
return False
def _validate_action(self, action: TimedAction):
"""Received actions are keps only when they have been produced for now or later, never before"""
return not action.get_timestep() <= self.latest_action
def _inspect_action_queue(self):
queue_size = self.action_queue.qsize()
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
return queue_size, timestamps
def _update_action_queue(self, actions: list[TimedAction]):
"""Update the action queue with new actions, without ever emptying the queue"""
new_queue = Queue()
for action in actions:
if self._validate_action(action):
new_queue.put(action)
self.action_queue = new_queue
def _aggregate_action_queues(
self,
incoming_actions: list[TimedAction],
aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
# TODO(fracapuano): move outside of the function and make aggregate_fn an always required argument
if not aggregate_fn:
# default aggregate function: take the latest action
def aggregate_fn(x1, x2):
return x2
action_intersections: list[torch.Tensor] = []
current_action_queue = {
action.get_timestep(): action.get_action() for action in self.action_queue.queue
}
for new_action in incoming_actions:
if new_action.get_timestep() in current_action_queue:
# TODO(fracapuano): There is probably a way to do this with broadcasting of the two action tensors
action_intersections.append(
TimedAction(
timestamp=new_action.get_timestamp(),
action=aggregate_fn(
current_action_queue[new_action.get_timestep()], new_action.get_action()
),
timestep=new_action.get_timestep(),
)
)
else:
action_intersections.append(new_action)
new_queue = Queue()
for action in action_intersections:
if self._validate_action(action):
new_queue.put(action)
self.action_queue = new_queue
def _clear_action_queue(self):
"""Clear the existing queue"""
while not self.action_queue.empty():
try:
self.action_queue.get_nowait()
except Empty:
break
def _fill_action_queue(self, actions: list[TimedAction]):
"""Fill the action queue with incoming valid actions"""
start_time = time.time()
valid_count = 0
for action in actions:
if self._validate_action(action):
self.action_queue.put(action)
valid_count += 1
end_time = time.time()
self.logger.debug(
f"Queue filled: {valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s"
)
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
self._clear_action_queue()
self._fill_action_queue(actions)
def receive_actions(self):
"""Receive actions from the policy server"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Action receiving thread starting")
while self.running:
try:
# Use StreamActions to get a stream of actions from the server
for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
receive_time = time.time()
# Deserialize bytes back into list[TimedAction]
deserialize_start = time.time()
timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_end = time.time()
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
start_time = time.time()
self.logger.info(f"Current latest action: {self.latest_action}")
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [self.latest_action] # queue was empty
# Log incoming actions
incoming_timesteps = [a.get_timestep() for a in timed_actions]
# Calculate network latency if we have matching observations
if len(timed_actions) > 0:
first_action_timestep = timed_actions[0].get_timestep()
server_to_client_latency = receive_time - self.last_obs_sent_time
self.logger.info(
f"Received action chunk for step #{first_action_timestep} | "
f"Latest action: #{self.latest_action} | "
f"Network latency (server->client): {server_to_client_latency:.6f}s | "
f"Deserialization time: {deserialize_end - deserialize_start:.6f}s"
)
# Update action queue
start_time = time.time()
self._update_action_queue(timed_actions)
queue_update_time = time.time() - start_time
self.must_go = (
True # after receiving actions, next empty queue triggers must-go processing!
)
# Get queue state after changes
new_size, new_timesteps = self._inspect_action_queue()
self.logger.info(
f"Queue update complete ({queue_update_time:.6f}s) | "
f"Before: {old_size} items | "
f"After: {new_size} items | "
)
self.logger.info(
f"Latest action: {self.latest_action} | "
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
)
except grpc.RpcError as e:
self.logger.error(f"Error receiving actions: {e}")
# Avoid tight loop on action receiver error
time.sleep(idle_wait)
def _actions_available(self):
"""Check if there are actions available in the queue"""
return not self.action_queue.empty()
def _get_next_action(self) -> Optional[TimedAction]:
"""Get the next action from the queue"""
try:
action = self.action_queue.get_nowait()
return action
except Empty:
return None
def _perform_action(self, timed_action: TimedAction):
self.robot.send_action(timed_action.get_action())
self.latest_action = timed_action.get_timestep()
self.logger.debug(
f"Ts={timed_action.get_timestamp()} | "
f"Action #{timed_action.get_timestep()} performed | "
f"Queue size: {self.action_queue.qsize()}"
)
def execute_actions(self):
"""Continuously execute actions from the queue"""
import warnings
warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2)
# Wait at barrier for synchronized start
self.start_barrier.wait()
time.sleep(idle_wait) # wait for observation capture to start
self.logger.info("Action execution thread starting")
while self.running:
# constantly monitor the size of the action queue
self.available_actions_size.append(self.action_queue.qsize())
if self._actions_available():
timed_action = self._get_next_action()
self._perform_action(timed_action)
time.sleep(environment_dt)
else:
self.logger.debug("No action available | Sleeping")
time.sleep(idle_wait)
def stream_observations(self, get_observation_fn):
"""Continuously stream observations to the server"""
import warnings
warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2)
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Observation streaming thread starting")
while self.running:
try:
# Get serialized observation bytes from the function
start_time = time.time()
observation = get_observation_fn()
obs_capture_time = time.time() - start_time
self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")
if not hasattr(self, "last_obs_timestamp"):
self.last_obs_timestamp = observation.get_timestamp()
obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()
self.logger.info(
f"Ts={obs_timestamp} | "
f"Captured observation #{obs_timestep} | "
f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
)
self.last_obs_timestamp = obs_timestamp
# Set appropriate transfer state
if obs_timestep == 0:
state = async_inference_pb2.TRANSFER_BEGIN
else:
state = async_inference_pb2.TRANSFER_MIDDLE
time.sleep(environment_dt)
self.send_observation(observation, state)
except Exception as e:
self.logger.error(f"Error in observation sender: {e}")
time.sleep(idle_wait)
def control_loop_action(self):
"""Reading and performing actions in local queue"""
self.available_actions_size.append(self.action_queue.qsize())
if self._actions_available():
# Get action from queue
get_start = time.time()
timed_action = self._get_next_action()
get_end = time.time() - get_start
self.logger.debug(
f"Popping action from queue to perform took {get_end:.6f}s | "
f"Queue size: {self.action_queue.qsize()}"
)
self._perform_action(timed_action)
def _ready_to_send_observation(self):
"""Flags when the client is ready to send an observation"""
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
def control_loop_observation(self, get_observation_fn):
try:
# Get serialized observation bytes from the function
start_time = time.time()
observation = get_observation_fn()
obs_capture_time = time.time() - start_time
# If there are no actions left in the queue, the observation must go through processing!
observation.must_go = self.must_go and self.action_queue.empty()
self.logger.debug(f"QUEUE SIZE: {self.action_queue.qsize()} (Must go: {observation.must_go})")
if observation.must_go:
# must-go flag will be set again after receiving actions
self.must_go = False
if not hasattr(self, "last_obs_timestamp"):
self.last_obs_timestamp = observation.get_timestamp()
obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()
self.last_obs_timestamp = obs_timestamp
self.logger.info(
f"Ts={obs_timestamp} | "
f"Captured observation #{obs_timestep} | "
f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
)
self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")
# Set appropriate transfer state
if obs_timestep == 0:
state = async_inference_pb2.TRANSFER_BEGIN
else:
state = async_inference_pb2.TRANSFER_MIDDLE
self.send_observation(observation, state)
except Exception as e:
self.logger.error(f"Error in observation sender: {e}")
def control_loop(self, get_observation_fn):
"""Combined function for executing actions and streaming observations"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Control loop thread starting")
control_loops = 0
while self.running:
control_loop_start = time.time()
self.control_loop_action()
"""Control loop: (2) Streaming observations to the remote policy server"""
if self._ready_to_send_observation() or control_loops == 0:
self.control_loop_observation(get_observation_fn)
# Dynamically adjust sleep time to maintain the desired control frequency
time.sleep(max(0, environment_dt - (time.time() - control_loop_start)))
control_loops += 1
def async_client(task_instruction: str, verbose: int = 0):
client = RobotClient()
if client.start():
# Function to get observations from the robot
def get_observation():
observation_content = None
observation_content = client.robot.capture_observation()
observation_content["task"] = [task_instruction]
observation = TimedObservation(
timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
)
return observation
client.logger.info("Starting all threads...")
# Create and start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions)
action_receiver_thread.daemon = True
control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
control_loop_thread.daemon = True
# Start all threads
action_receiver_thread.start()
control_loop_thread.start()
try:
while client.running:
time.sleep(idle_wait)
except KeyboardInterrupt:
pass
finally:
client.stop()
client.logger.info("Client stopped")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Robot client for executing tasks via policy server")
parser.add_argument(
"--task",
type=str,
required=True,
help="Task instruction for the robot to execute (e.g., 'fold my tshirt')",
)
parser.add_argument("--verbose", type=int, default=0, help="Verbosity level (default: 0)")
parser.add_argument(
"--server-port-address",
type=str,
default="localhost:8080",
help="Server & port address (default: localhost:8080, or SERVER_ADDRESS env var)",
)
parser.add_argument("--policy-type", type=str, default="smolvla", help="Policy type (default: smolvla)")
parser.add_argument(
"--pretrained-name-or-path",
type=str,
default="lerobot/smolvla_base",
help="Pretrained model name or path (default: lerobot/smolvla_base)",
)
parser.add_argument(
"--policy-device", type=str, default="cuda", help="Device for policy inference (default: cuda)"
)
parser.add_argument(
"--chunk-size-threshold",
type=float,
default=0.5,
help="Chunk size threshold (`g` in the paper, default: 0.5)",
)
parser.add_argument(
"--robot",
type=str,
default="so100",
help="Robot name, as per the `make_robot` function (default: so100)",
)
args = parser.parse_args()
# Create client with parsed arguments
client = RobotClient(
server_address=args.server_port_address,
policy_type=args.policy_type,
pretrained_name_or_path=args.pretrained_name_or_path,
policy_device=args.policy_device,
chunk_size_threshold=args.chunk_size_threshold,
robot=args.robot,
)
if client.start():
# Function to get observations from the robot
def get_observation():
observation_content = None
observation_content = client.robot.capture_observation()
observation_content["task"] = [args.task]
observation = TimedObservation(
timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
)
return observation
client.logger.info("Starting all threads...")
# Create and start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions)
action_receiver_thread.daemon = True
control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
control_loop_thread.daemon = True
# Start all threads
action_receiver_thread.start()
control_loop_thread.start()
try:
while client.running:
time.sleep(idle_wait)
except KeyboardInterrupt:
pass
finally:
client.stop()
client.logger.info("Client stopped")

View File

@@ -68,8 +68,8 @@ dependencies = [
"pyzmq>=26.2.1",
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"torch>=2.2.1,<2.7",
"torchcodec==0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torch>=2.2.1",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0",

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0389a716d51c1c615fb2a3bfa386d89f00b0deca08c4fa21b23e020a939d0213
oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1
size 3686488

View File

@@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import (
from lerobot.common.utils.random_utils import seeded_context
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0dc691503e7d90b2086bb408e89a65f772ce5ee6e3562ef8c127bcb09bd90851
oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90
size 40551392

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
size 5104

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4
size 31672

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36
size 68

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
size 33400

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
size 515400

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
size 31672

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
size 68

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
size 33400

View File

@@ -16,6 +16,7 @@
import pytest
import torch
from packaging import version
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
@@ -253,7 +254,14 @@ def test_backward_compatibility_single_transforms(
@require_x86_64_kernel
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("2.7.0"),
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
)
def test_backward_compatibility_default_config(img_tensor, default_transforms):
# NOTE: PyTorch versions have different randomness, it might break this test.
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
cfg = ImageTransformsConfig(enable=True)
default_tf = ImageTransforms(cfg)

View File

@@ -37,7 +37,6 @@ def test_diffuser_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
@@ -56,7 +55,6 @@ def test_vqbet_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
@@ -77,7 +75,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict

View File

@@ -20,6 +20,7 @@ from pathlib import Path
import einops
import pytest
import torch
from packaging import version
from safetensors.torch import load_file
from lerobot import available_policies
@@ -408,7 +409,16 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
is out of date. For example, some PyTorch versions have different randomness, see this PR:
https://github.com/huggingface/lerobot/pull/1127.
"""
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")