diff --git a/lerobot/scripts/server/async_inference.proto b/lerobot/scripts/server/async_inference.proto new file mode 100644 index 00000000..8eac7ef9 --- /dev/null +++ b/lerobot/scripts/server/async_inference.proto @@ -0,0 +1,56 @@ +// Copyright 2024 The HuggingFace Inc. team. +// All rights reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +package async_inference; + +// AsyncInference: from the Robot's perspective +// The Robot sends observations to, and executes actions received from, a remote Policy server +service AsyncInference { + // Robot -> Policy to share observations with a remote inference server + // Policy -> Robot to share actions predicted for given observations + rpc SendObservations(stream Observation) returns (Empty); + rpc StreamActions(Empty) returns (stream Action); + rpc SendPolicyInstructions(PolicySetup) returns (Empty); + rpc Ready(Empty) returns (Empty); +} + +enum TransferState { + TRANSFER_UNKNOWN = 0; + TRANSFER_BEGIN = 1; + TRANSFER_MIDDLE = 2; + TRANSFER_END = 3; +} + +// Messages +message Observation { + // sent by Robot, to remote Policy + TransferState transfer_state = 1; + bytes data = 2; +} + +message Action { + // sent by remote Policy, to Robot + TransferState transfer_state = 1; + bytes data = 2; +} + +message PolicySetup { + // sent by Robot to remote server, to init Policy + TransferState transfer_state = 1; + bytes data = 2; +} + +message Empty {} diff --git a/lerobot/scripts/server/async_inference_pb2.py b/lerobot/scripts/server/async_inference_pb2.py new file mode 100644 index 00000000..e2d18d6f --- /dev/null +++ b/lerobot/scripts/server/async_inference_pb2.py @@ -0,0 +1,48 @@ +# fmt: off +# flake8: noqa +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE +# source: async_inference.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'async_inference.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x0bPolicySetup\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xa9\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TRANSFERSTATE']._serialized_start=301 + _globals['_TRANSFERSTATE']._serialized_end=397 + _globals['_OBSERVATION']._serialized_start=42 + _globals['_OBSERVATION']._serialized_end=125 + _globals['_ACTION']._serialized_start=127 + _globals['_ACTION']._serialized_end=205 + _globals['_POLICYSETUP']._serialized_start=207 + _globals['_POLICYSETUP']._serialized_end=290 + _globals['_EMPTY']._serialized_start=292 + _globals['_EMPTY']._serialized_end=299 + _globals['_ASYNCINFERENCE']._serialized_start=400 + _globals['_ASYNCINFERENCE']._serialized_end=697 +# @@protoc_insertion_point(module_scope) diff --git a/lerobot/scripts/server/async_inference_pb2_grpc.py b/lerobot/scripts/server/async_inference_pb2_grpc.py new file mode 100644 index 00000000..b0ab0f50 --- /dev/null +++ b/lerobot/scripts/server/async_inference_pb2_grpc.py @@ -0,0 +1,236 @@ +# fmt: off +# flake8: noqa +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import async_inference_pb2 as async__inference__pb2 + +GRPC_GENERATED_VERSION = '1.71.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in async_inference_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class AsyncInferenceStub: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendObservations = channel.stream_unary( + '/async_inference.AsyncInference/SendObservations', + request_serializer=async__inference__pb2.Observation.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + self.StreamActions = channel.unary_stream( + '/async_inference.AsyncInference/StreamActions', + request_serializer=async__inference__pb2.Empty.SerializeToString, + response_deserializer=async__inference__pb2.Action.FromString, + _registered_method=True) + self.SendPolicyInstructions = channel.unary_unary( + '/async_inference.AsyncInference/SendPolicyInstructions', + request_serializer=async__inference__pb2.PolicySetup.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/async_inference.AsyncInference/Ready', + request_serializer=async__inference__pb2.Empty.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + + +class AsyncInferenceServicer: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def SendObservations(self, request_iterator, context): + """Robot -> Policy to share observations with a remote inference server + Policy -> Robot to share actions predicted for given observations + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def StreamActions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendPolicyInstructions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_AsyncInferenceServicer_to_server(servicer, server): + 
rpc_method_handlers = { + 'SendObservations': grpc.stream_unary_rpc_method_handler( + servicer.SendObservations, + request_deserializer=async__inference__pb2.Observation.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + 'StreamActions': grpc.unary_stream_rpc_method_handler( + servicer.StreamActions, + request_deserializer=async__inference__pb2.Empty.FromString, + response_serializer=async__inference__pb2.Action.SerializeToString, + ), + 'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler( + servicer.SendPolicyInstructions, + request_deserializer=async__inference__pb2.PolicySetup.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=async__inference__pb2.Empty.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'async_inference.AsyncInference', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class AsyncInference: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + @staticmethod + def SendObservations(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/async_inference.AsyncInference/SendObservations', + async__inference__pb2.Observation.SerializeToString, + async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def StreamActions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/async_inference.AsyncInference/StreamActions', + async__inference__pb2.Empty.SerializeToString, + async__inference__pb2.Action.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendPolicyInstructions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/async_inference.AsyncInference/SendPolicyInstructions', + async__inference__pb2.PolicySetup.SerializeToString, + async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/async_inference.AsyncInference/Ready', + async__inference__pb2.Empty.SerializeToString, + 
async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/lerobot/scripts/server/constants.py b/lerobot/scripts/server/constants.py new file mode 100644 index 00000000..d510c206 --- /dev/null +++ b/lerobot/scripts/server/constants.py @@ -0,0 +1,12 @@ +"""Server/Client side: idle sleep (in seconds) used while waiting for queues, connections, or cameras""" + +idle_wait = 0.01 + +"""Client side: The environment evolves with a time resolution equal to environment_dt""" +environment_dt = 1 / 30 + +"""Server side: target latency per inference call; calls faster than this are padded to take (at least) environment_dt""" +inference_latency = environment_dt + +"""Supported policies""" +supported_policies = ["act", "smolvla"] diff --git a/lerobot/scripts/server/helpers.py b/lerobot/scripts/server/helpers.py new file mode 100644 index 00000000..8325c9fc --- /dev/null +++ b/lerobot/scripts/server/helpers.py @@ -0,0 +1,128 @@ +import logging +import logging.handlers +import os +import time +from typing import Any + +import torch + + +def setup_logging(prefix: str, info_bracket: str): + """Set up logging to the console and to a per-run rotating log file""" + # Create logs directory if it doesn't exist + os.makedirs("logs", exist_ok=True) + + # Delete any existing prefix_* log files + for old_log_file in os.listdir("logs"): + if old_log_file.startswith(prefix) and old_log_file.endswith(".log"): + try: + os.remove(os.path.join("logs", old_log_file)) + print(f"Deleted old log file: {old_log_file}") + except Exception as e: + print(f"Failed to delete old log file {old_log_file}: {e}") + + # Set up logging with both console and file output + logger = logging.getLogger(prefix) + # Prevent propagation to root logger to avoid duplicate messages + logger.propagate = False + + logger.setLevel(logging.INFO) + + # Console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter( + logging.Formatter( + f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + logger.addHandler(console_handler) + + # File handler - creates a new log file for each run + file_handler = logging.handlers.RotatingFileHandler( + f"logs/{prefix}_{int(time.time())}.log", + maxBytes=10 * 1024 * 1024, # 10MB + backupCount=5, + ) + file_handler.setFormatter( + logging.Formatter( + f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + logger.addHandler(file_handler) + + return logger + + +class TimedData: + def __init__(self, timestamp: float, data: Any, timestep: int): + """Initialize a TimedData object. + + Args: + timestamp: Unix timestamp at the data's creation. + data: The payload to wrap with timing information. + timestep: The timestep of the data.
+ """ + self.timestamp = timestamp + self.data = data + self.timestep = timestep + + def get_data(self): + return self.data + + def get_timestamp(self): + return self.timestamp + + def get_timestep(self): + return self.timestep + + +class TimedAction(TimedData): + def __init__(self, timestamp: float, action: torch.Tensor, timestep: int): + super().__init__(timestamp=timestamp, data=action, timestep=timestep) + + def get_action(self): + return self.get_data() + + +class TimedObservation(TimedData): + def __init__( + self, + timestamp: float, + observation: dict[str, torch.Tensor], + timestep: int, + transfer_state: int = 0, + must_go: bool = False, + ): + super().__init__(timestamp=timestamp, data=observation, timestep=timestep) + self.transfer_state = transfer_state + self.must_go = must_go + + def get_observation(self): + return self.get_data() + + +class TinyPolicyConfig: + def __init__( + self, + policy_type: str = "act", + pretrained_name_or_path: str = "fracapuano/act_so100_test", + device: str = "cpu", + ): + self.policy_type = policy_type + self.pretrained_name_or_path = pretrained_name_or_path + self.device = device + + +def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool: + """Check if two observation states are similar, under a tolerance threshold""" + return torch.linalg.norm(obs1_state - obs2_state) < atol + + +def observations_similar(obs1: TimedObservation, obs2: TimedObservation, atol: float = 1) -> bool: + """Check if two observations are similar, under a tolerance threshold""" + obs1_state = obs1.get_observation()["observation.state"] + obs2_state = obs2.get_observation()["observation.state"] + + return _compare_observation_states(obs1_state, obs2_state, atol=atol) diff --git a/lerobot/scripts/server/policy_server.py b/lerobot/scripts/server/policy_server.py new file mode 100644 index 00000000..96a9758e --- /dev/null +++ b/lerobot/scripts/server/policy_server.py @@ -0,0 +1,429 @@ +import itertools +import pickle # nosec +import time +from concurrent import futures +from queue import Queue +from typing import Generator, List, Optional + +import async_inference_pb2 # type: ignore +import async_inference_pb2_grpc # type: ignore +import grpc +import torch +from datasets import load_dataset + +from lerobot.common.policies.factory import get_policy_class +from lerobot.scripts.server.constants import environment_dt, idle_wait, inference_latency, supported_policies +from lerobot.scripts.server.helpers import ( + TimedAction, + TimedObservation, + TinyPolicyConfig, + observations_similar, + setup_logging, +) + + +class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): + prefix = "policy_server" + info_bracket = "SERVER" + logger = setup_logging(prefix, info_bracket) + + def __init__(self): + # Initialize dataset action generator (to debug this first version, will be removed in the future) + self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset()) + + self._setup_server() + + self.actions_per_chunk = 20 + self.actions_overlap = 10 + + self.running = True + + def _setup_server(self) -> None: + """Flushes server state when new client connects.""" + # only running inference on the latest observation received by the server + self.observation_queue = Queue(maxsize=1) + self._predicted_timesteps = set() + self._predicted_observations = Queue(maxsize=1) + + def Ready(self, request, context): # noqa: N802 + client_id = context.peer() + self.logger.info(f"Client {client_id} connected and ready") + 
self._setup_server() + + return async_inference_pb2.Empty() + + def SendPolicyInstructions(self, request, context): # noqa: N802 + """Receive policy instructions from the robot client""" + client_id = context.peer() + self.logger.debug(f"Receiving policy instructions from {client_id}") + + policy_specs = pickle.loads(request.data) # nosec + assert isinstance(policy_specs, TinyPolicyConfig), ( + f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}" + ) + + self.logger.info( + f"Policy type: {policy_specs.policy_type} | " + f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | " + f"Device: {policy_specs.device}" + ) + + assert policy_specs.policy_type in supported_policies, ( + f"Policy type {policy_specs.policy_type} not supported. Supported policies: {supported_policies}" + ) + + self.device = policy_specs.device + self.policy_type = policy_specs.policy_type # act, pi0, etc. + + policy_class = get_policy_class(self.policy_type) + + start = time.time() + self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path) + self.policy.to(self.device) + end = time.time() + + self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds") + + return async_inference_pb2.Empty() + + def SendObservations(self, request_iterator, context): # noqa: N802 + """Receive observations from the robot client""" + client_id = context.peer() + self.logger.debug(f"Receiving observations from {client_id}") + + for observation in request_iterator: + receive_time = time.time() + timed_observation = pickle.loads(observation.data) # nosec + deserialize_time = time.time() + + self.logger.debug(f"Received observation #{timed_observation.get_timestep()}") + + if not self._maybe_enqueue_observation(timed_observation): + continue + + queue_time = time.time() + + obs_timestep = timed_observation.get_timestep() + obs_timestamp = timed_observation.get_timestamp() + + self.logger.info( + f"Received observation #{obs_timestep} | " + f"Client timestamp: {obs_timestamp:.6f} | " + f"Server timestamp: {receive_time:.6f} | " + ) + + if not hasattr(self, "previous_obs_timestamp"): + self.previous_obs_timestamp = obs_timestamp + + self.logger.debug( + f"1/DeltaObsT (~frequency): {1 / (1e-6 + obs_timestamp - self.previous_obs_timestamp):.6f} Hz| " + f"Network latency: {receive_time - obs_timestamp:.6f}s | " + f"Deserialization time: {deserialize_time - receive_time:.6f}s | " + f"Queue time: {queue_time - deserialize_time:.6f}s | " + ) + + self.previous_obs_timestamp = obs_timestamp + + return async_inference_pb2.Empty() + + def StreamActions(self, request, context): # noqa: N802 + """Stream actions to the robot client""" + client_id = context.peer() + self.logger.debug(f"Client {client_id} connected for action streaming") + + # Generate action based on the most recent observation and its timestep + try: + obs = self.observation_queue.get() + self.logger.info( + f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})" + ) + + if obs: + self.last_predicted_obs = obs + self._predicted_timesteps.add(obs.get_timestep()) + start_time = time.time() + action_chunk = self._predict_action_chunk(obs) + # action_chunk = self._read_action_chunk(obs) + inference_time = time.time() - start_time + + start_time = time.time() + action_bytes = pickle.dumps(action_chunk) # nosec + serialize_time = time.time() - start_time + + # Create and return the Action + action = async_inference_pb2.Action(transfer_state=obs.transfer_state, data=action_bytes) + + 
self.logger.info( + f"Action chunk #{obs.get_timestep()} generated | Inference time: {inference_time:.6f}s |" + ) + + self.logger.debug( + f"Action chunk #{obs.get_timestep()} generated | " + f"Inference time: {inference_time:.6f}s | " + f"Serialize time: {serialize_time:.6f}s | " + f"Total time: {inference_time + serialize_time:.6f}s" + ) + + yield action + else: + self.logger.warning("No observation in queue yet!") + time.sleep(idle_wait) + + except Exception as e: + self.logger.error(f"Error in StreamActions: {e}") + + return async_inference_pb2.Empty() + + def _enqueue_and_go(self, obs: TimedObservation): + # If queue is full, get the old observation to make room + if self.observation_queue.full(): + # pops from queue + _ = self.observation_queue.get_nowait() + self.logger.debug("Observation queue was full, removed oldest observation") + + # Now put the new observation (never blocks as queue is non-full here) + self.observation_queue.put(obs) + return True + + def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool: + if obs.get_timestep() in self._predicted_timesteps: + self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!") + return False + + elif observations_similar(obs, previous_obs, atol=1): + self.logger.debug( + f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!" + ) + return False + + else: + return True + + def _maybe_enqueue_observation(self, obs: TimedObservation) -> bool: + """Enqueue an observation if it must go through processing, otherwise skip it. + Observations not in the queue are never run through the policy network.""" + + if obs.must_go or not hasattr(self, "last_predicted_obs"): + self.logger.info(f"[MUST GO] Enqueued observation #{obs.get_timestep()} for direct processing!") + return self._enqueue_and_go(obs) + + else: + if self._obs_sanity_checks(obs, self.last_predicted_obs): + return self._enqueue_and_go(obs) + else: + return False + + def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]: + """Turn a chunk of actions into a list of TimedAction instances, + with the first action corresponding to t_0 and the rest corresponding to + t_0 + i*environment_dt for i in range(len(action_chunk)) + """ + return [ + TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk) + ] + + @torch.no_grad() + def _run_act_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: + """Run ACT-like policies""" + start_time = time.time() + + # prepare observation for policy forward pass + batch = self.policy.normalize_inputs(observation) + normalize_time = time.time() + self.logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s") + + if self.policy.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.policy.config.image_features] + prep_time = time.time() + self.logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s") + + # the forward pass outputs up to policy.config.n_action_steps actions, which may differ from actions_per_chunk, hence the slicing + actions = self.policy.model(batch)[0][:, : self.actions_per_chunk] + + actions = self.policy.unnormalize_outputs({"action": actions})["action"] + + end_time = time.time() + self.logger.info(f"[ACT] Action chunk generation total time: {end_time - start_time:.6f}s") + + return actions + + @torch.no_grad() +
def _run_pi0_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: + """Run PI0-like policies""" + raise NotImplementedError("PI0 policy not implemented yet") + + @torch.no_grad() + def _run_smolvla_policy( + self, observation: dict[str, torch.Tensor], noise: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Run smolvla-like policies""" + observation = self.policy.normalize_inputs(observation) + + images, img_masks = self.policy.prepare_images(observation) + state = self.policy.prepare_state(observation) + lang_tokens, lang_masks = self.policy.prepare_language(observation) + + actions = self.policy.model.sample_actions( + images, img_masks, lang_tokens, lang_masks, state, noise=noise + ) + + # Unpad actions + original_action_dim = self.policy.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.policy.unnormalize_outputs( + {"action": actions, "robot_type": [self.policy.config.robot_type]} + )["action"] + + return actions + + def _get_action_chunk( + self, observation: dict[str, torch.Tensor], policy_type: str = "act" + ) -> torch.Tensor: + """Get an action chunk from the policy""" + if policy_type == "act": + return self._run_act_policy(observation) + elif policy_type == "smolvla": + return self._run_smolvla_policy(observation) + else: + raise ValueError(f"Policy type {policy_type} not supported") + + def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]: + """Predict an action chunk based on the observation""" + """1. Prepare observation""" + start_time = time.time() + + observation = { + "robot_type": [self.policy.config.robot_type], + } + for k, v in observation_t.get_observation().items(): + if isinstance(v, torch.Tensor): # non-tensor values (e.g. VLA natural-language instructions) are handled below + if "image" in k: + # Add batch dimension first, then reorder to NCHW format, then normalize to [0, 1] + observation[k] = ( + v.unsqueeze(0).permute(0, 3, 1, 2).to(self.device, non_blocking=True) / 255.0 + ) + else: + observation[k] = v.unsqueeze(0).to(self.device, non_blocking=True) + else: + observation[k] = v # textual instructions are passed as a list of strings + + prep_time = time.time() + self.logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s") + + """2. Get action chunk""" + action_tensor = self._get_action_chunk(observation, self.policy_type) + action_tensor = action_tensor.squeeze(0) + + # Move to CPU before serializing + action_tensor = action_tensor.cpu() + + post_inference_time = time.time() + self.logger.debug(f"Post-inference processing start: {post_inference_time - prep_time:.6f}s") + + if action_tensor.dim() == 1: + # No chunk dimension, so repeat action to create a (dummy) chunk of actions + action_tensor = action_tensor.repeat(self.actions_per_chunk, 1) + + action_chunk = self._time_action_chunk( + observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() + ) + + chunk_time = time.time() + self.logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s") + time.sleep( + max(0, inference_latency - max(0, chunk_time - start_time)) + ) # sleep to control inference latency + + return action_chunk + + def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]: + """Stream chunks of actions from a prerecorded dataset.
+ + Returns: + Generator that yields chunks of actions from the dataset + """ + import warnings + + warnings.warn( + "This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2 + ) + + dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch") + + # 1. Select the action column only, which holds the 6-dimensional action tensors + actions = dataset["action"] + action_indices = torch.arange(len(actions)) + + # 2. Chunk the actions into overlapping windows of actions_per_chunk elements, + # advancing by (actions_per_chunk - actions_overlap) steps between windows + indices_chunks = action_indices.unfold( + 0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap + ) + + for idx_chunk in indices_chunks: + yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :] + + def _read_action_chunk(self, observation: Optional[TimedObservation] = None) -> list[TimedAction]: + """Dummy function for predicting action chunk given observation. + + Instead of computing actions on-the-fly, this method streams + actions from a prerecorded dataset. + """ + import warnings + + warnings.warn( + "This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2 + ) + + start_time = time.time() + if not observation: + observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0) + + # Get chunk of actions from the generator + actions_chunk = next(self.action_generator) + + # Return a list of TimedActions, with timestamps starting from the observation timestamp + actions_chunk = self._time_action_chunk( + observation.get_timestamp(), actions_chunk, observation.get_timestep() + ) + + chunk_time = time.time() + self.logger.debug(f"Action chunk creation time: {chunk_time - start_time:.6f}s") + + # slow action generation, emulates inference time + time.sleep(max(0, inference_latency - max(0, chunk_time - start_time))) + + return actions_chunk + + def stop(self): + """Stop the server""" + self.running = False + self.logger.info("Server stopping...") + + +def serve(): + port = 8080 + # Create the server instance first + policy_server = PolicyServer() + + # Setup and start gRPC server + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) + server.add_insecure_port(f"[::]:{port}") + server.start() + policy_server.logger.info(f"PolicyServer started on port {port}") + + try: + # Use the running attribute to control server lifetime + while policy_server.running: + time.sleep(1) # Check every second instead of sleeping indefinitely + + except KeyboardInterrupt: + policy_server.stop() + policy_server.logger.info("Keyboard interrupt received") + + +if __name__ == "__main__": + serve() diff --git a/lerobot/scripts/server/robot_client.py b/lerobot/scripts/server/robot_client.py new file mode 100644 index 00000000..d84a0f70 --- /dev/null +++ b/lerobot/scripts/server/robot_client.py @@ -0,0 +1,608 @@ +import argparse +import os +import pickle # nosec +import threading +import time +from queue import Empty, Queue +from typing import Callable, Optional + +import async_inference_pb2 # type: ignore +import async_inference_pb2_grpc # type: ignore +import grpc +import torch + +from lerobot.common.robot_devices.robots.utils import make_robot +from lerobot.scripts.server.constants import environment_dt, idle_wait +from lerobot.scripts.server.helpers import TimedAction, TimedObservation, TinyPolicyConfig, setup_logging + + +class RobotClient: + prefix = "robot_client" +
info_bracket = "CLIENT" + logger = setup_logging(prefix, info_bracket) + + def __init__( + self, + server_address: Optional[str] = None, + policy_type: str = "smolvla", + pretrained_name_or_path: str = "lerobot/smolvla_base", + policy_device: str = "cuda", + chunk_size_threshold: float = 0.5, + robot: str = "so100", + ): + # Use environment variable if server_address is not provided + if server_address is None: + server_address = os.getenv("SERVER_ADDRESS", "localhost:8080") + self.logger.info(f"No server address provided, using default address: {server_address}") + + self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device) + self.channel = grpc.insecure_channel(server_address) + self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) + self.logger.info(f"Initializing client to connect to server at {server_address}") + + self.running = False + self.must_go = True # does the observation qualify for direct processing on the policy server? + + self.latest_action = -1 + self.action_chunk_size = -1 + + self._chunk_size_threshold = chunk_size_threshold + + self.action_queue = Queue() + self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop + + start_time = time.time() + self.robot = make_robot(robot) + self.robot.connect() + + connect_time = time.time() + self.logger.info(f"Robot connection time: {connect_time - start_time:.4f}s") + + time.sleep(idle_wait) # sleep waiting for cameras to activate + self.logger.info("Robot connected and ready") + + def timestamps(self): + """Get the timestamps of the actions in the queue""" + return sorted([action.get_timestep() for action in self.action_queue.queue]) + + def start(self): + """Start the robot client and connect to the policy server""" + try: + # client-server handshake + start_time = time.time() + self.stub.Ready(async_inference_pb2.Empty()) + end_time = time.time() + self.logger.info(f"Connected to policy server in {end_time - start_time:.4f}s") + + # send policy instructions + policy_config_bytes = pickle.dumps(self.policy_config) + policy_setup = async_inference_pb2.PolicySetup( + transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes + ) + + self.logger.info("Sending policy instructions to policy server") + self.logger.info( + f"Policy type: {self.policy_config.policy_type} | " + f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | " + f"Device: {self.policy_config.device}" + ) + + self.stub.SendPolicyInstructions(policy_setup) + + self.running = True + self.available_actions_size = [] + return True + + except grpc.RpcError as e: + self.logger.error(f"Failed to connect to policy server: {e}") + return False + + def stop(self): + """Stop the robot client""" + self.running = False + + self.robot.disconnect() + self.logger.info("Robot disconnected") + + self.channel.close() + self.logger.info("Client stopped, channel closed") + + def send_observation( + self, + obs: TimedObservation, + transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE, + ) -> bool: + """Send observation to the policy server. + Returns True if the observation was sent successfully, False otherwise.""" + if not self.running: + self.logger.warning("Client not running") + return False + + assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!" 
+ + start_time = time.time() + observation_bytes = pickle.dumps(obs) + serialize_time = time.time() + self.logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s") + + observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes) + + try: + send_start = time.time() + _ = self.stub.SendObservations(iter([observation])) + send_end = time.time() + + obs_timestep = obs.get_timestep() + + self.logger.info( + f"Sent observation #{obs_timestep} | " + f"Serialize time: {serialize_time - start_time:.6f}s | " + f"Network time: {send_end - send_start:.6f}s | " + f"Total time: {send_end - start_time:.6f}s" + ) + + self.last_obs_sent_time = send_end + return True + + except grpc.RpcError as e: + self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}") + return False + + def _validate_action(self, action: TimedAction): + """Received actions are kept only when they were produced for the current timestep or later, never for an earlier one""" + return action.get_timestep() > self.latest_action + + def _inspect_action_queue(self): + queue_size = self.action_queue.qsize() + timesteps = sorted([action.get_timestep() for action in self.action_queue.queue]) + self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timesteps}") + return queue_size, timesteps + + def _update_action_queue(self, actions: list[TimedAction]): + """Replace the action queue with the incoming valid actions, swapping the queue object rather than draining it in place""" + + new_queue = Queue() + for action in actions: + if self._validate_action(action): + new_queue.put(action) + + self.action_queue = new_queue + + def _aggregate_action_queues( + self, + incoming_actions: list[TimedAction], + aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ): + """Aggregate incoming actions that share a timestep with actions already in the queue, using aggregate_fn""" + # TODO(fracapuano): move outside of the function and make aggregate_fn an always required argument + if not aggregate_fn: + # default aggregate function: take the latest action + def aggregate_fn(x1, x2): + return x2 + + action_intersections: list[TimedAction] = [] + current_action_queue = { + action.get_timestep(): action.get_action() for action in self.action_queue.queue + } + + for new_action in incoming_actions: + if new_action.get_timestep() in current_action_queue: + # TODO(fracapuano): There is probably a way to do this with broadcasting of the two action tensors + action_intersections.append( + TimedAction( + timestamp=new_action.get_timestamp(), + action=aggregate_fn( + current_action_queue[new_action.get_timestep()], new_action.get_action() + ), + timestep=new_action.get_timestep(), + ) + ) + else: + action_intersections.append(new_action) + + new_queue = Queue() + for action in action_intersections: + if self._validate_action(action): + new_queue.put(action) + + self.action_queue = new_queue + + def _clear_action_queue(self): + """Clear the existing queue""" + while not self.action_queue.empty(): + try: + self.action_queue.get_nowait() + except Empty: + break + + def _fill_action_queue(self, actions: list[TimedAction]): + """Fill the action queue with incoming valid actions""" + start_time = time.time() + valid_count = 0 + + for action in actions: + if self._validate_action(action): + self.action_queue.put(action) + valid_count += 1 + + end_time = time.time() + self.logger.debug( + f"Queue filled: {valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s" + ) + + def _clear_and_fill_action_queue(self,
actions: list[TimedAction]): + self._clear_action_queue() + self._fill_action_queue(actions) + + def receive_actions(self): + """Receive actions from the policy server""" + # Wait at barrier for synchronized start + self.start_barrier.wait() + self.logger.info("Action receiving thread starting") + + while self.running: + try: + # Use StreamActions to get a stream of actions from the server + for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()): + receive_time = time.time() + + # Deserialize bytes back into list[TimedAction] + deserialize_start = time.time() + timed_actions = pickle.loads(actions_chunk.data) # nosec + deserialize_end = time.time() + + self.action_chunk_size = max(self.action_chunk_size, len(timed_actions)) + + self.logger.info(f"Current latest action: {self.latest_action}") + + # Get queue state before changes + old_size, old_timesteps = self._inspect_action_queue() + if not old_timesteps: + old_timesteps = [self.latest_action] # queue was empty + + # Log incoming actions + incoming_timesteps = [a.get_timestep() for a in timed_actions] + + # Calculate network latency if we have matching observations + if len(timed_actions) > 0: + first_action_timestep = timed_actions[0].get_timestep() + server_to_client_latency = receive_time - self.last_obs_sent_time + + self.logger.info( + f"Received action chunk for step #{first_action_timestep} | " + f"Latest action: #{self.latest_action} | " + f"Network latency (server->client): {server_to_client_latency:.6f}s | " + f"Deserialization time: {deserialize_end - deserialize_start:.6f}s" + ) + + # Update action queue + start_time = time.time() + self._update_action_queue(timed_actions) + queue_update_time = time.time() - start_time + + self.must_go = ( + True # after receiving actions, next empty queue triggers must-go processing! + ) + + # Get queue state after changes + new_size, new_timesteps = self._inspect_action_queue() + + self.logger.info( + f"Queue update complete ({queue_update_time:.6f}s) | " + f"Before: {old_size} items | " + f"After: {new_size} items | " + ) + self.logger.info( + f"Latest action: {self.latest_action} | " + f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | " + f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | " + f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}" + ) + + except grpc.RpcError as e: + self.logger.error(f"Error receiving actions: {e}") + # Avoid tight loop on action receiver error + time.sleep(idle_wait) + + def _actions_available(self): + """Check if there are actions available in the queue""" + return not self.action_queue.empty() + + def _get_next_action(self) -> Optional[TimedAction]: + """Get the next action from the queue""" + try: + action = self.action_queue.get_nowait() + return action + + except Empty: + return None + + def _perform_action(self, timed_action: TimedAction): + self.robot.send_action(timed_action.get_action()) + self.latest_action = timed_action.get_timestep() + + self.logger.debug( + f"Ts={timed_action.get_timestamp()} | " + f"Action #{timed_action.get_timestep()} performed | " + f"Queue size: {self.action_queue.qsize()}" + ) + + def execute_actions(self): + """Continuously execute actions from the queue""" + import warnings + + warnings.warn("This method is deprecated!
Will be removed soon!", stacklevel=2) + # Wait at barrier for synchronized start + self.start_barrier.wait() + time.sleep(idle_wait) # wait for observation capture to start + + self.logger.info("Action execution thread starting") + + while self.running: + # constantly monitor the size of the action queue + self.available_actions_size.append(self.action_queue.qsize()) + + if self._actions_available(): + timed_action = self._get_next_action() + self._perform_action(timed_action) + + time.sleep(environment_dt) + + else: + self.logger.debug("No action available | Sleeping") + time.sleep(idle_wait) + + def stream_observations(self, get_observation_fn): + """Continuously stream observations to the server""" + import warnings + + warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2) + + # Wait at barrier for synchronized start + self.start_barrier.wait() + self.logger.info("Observation streaming thread starting") + + while self.running: + try: + # Capture a TimedObservation via the provided function + start_time = time.time() + observation = get_observation_fn() + obs_capture_time = time.time() - start_time + + self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s") + + if not hasattr(self, "last_obs_timestamp"): + self.last_obs_timestamp = observation.get_timestamp() + + obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp() + self.logger.info( + f"Ts={obs_timestamp} | " + f"Captured observation #{obs_timestep} | " + f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}" + ) + + self.last_obs_timestamp = obs_timestamp + + # Set appropriate transfer state + if obs_timestep == 0: + state = async_inference_pb2.TRANSFER_BEGIN + else: + state = async_inference_pb2.TRANSFER_MIDDLE + + time.sleep(environment_dt) + self.send_observation(observation, state) + + except Exception as e: + self.logger.error(f"Error in observation sender: {e}") + time.sleep(idle_wait) + + def control_loop_action(self): + """Read and perform actions from the local queue""" + self.available_actions_size.append(self.action_queue.qsize()) + if self._actions_available(): + # Get action from queue + get_start = time.time() + timed_action = self._get_next_action() + get_duration = time.time() - get_start + + self.logger.debug( + f"Popping action from queue to perform took {get_duration:.6f}s | " + f"Queue size: {self.action_queue.qsize()}" + ) + + self._perform_action(timed_action) + + def _ready_to_send_observation(self): + """Flags when the client is ready to send an observation""" + return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold + + def control_loop_observation(self, get_observation_fn): + try: + # Capture a TimedObservation via the provided function + start_time = time.time() + observation = get_observation_fn() + obs_capture_time = time.time() - start_time + + # If there are no actions left in the queue, the observation must go through processing!
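+ # must_go is disarmed right after use (below) and only re-armed once a fresh action chunk arrives (see receive_actions), so a drained queue triggers at most one forced inference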
+ observation.must_go = self.must_go and self.action_queue.empty() + self.logger.debug(f"QUEUE SIZE: {self.action_queue.qsize()} (Must go: {observation.must_go})") + if observation.must_go: + # must-go flag will be set again after receiving actions + self.must_go = False + + if not hasattr(self, "last_obs_timestamp"): + self.last_obs_timestamp = observation.get_timestamp() + + obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp() + + self.logger.info( + f"Ts={obs_timestamp} | " + f"Captured observation #{obs_timestep} | " + f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}" + ) + # update only after logging, so the frequency estimate uses the previous capture time + self.last_obs_timestamp = obs_timestamp + + self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s") + + # Set appropriate transfer state + if obs_timestep == 0: + state = async_inference_pb2.TRANSFER_BEGIN + else: + state = async_inference_pb2.TRANSFER_MIDDLE + + self.send_observation(observation, state) + + except Exception as e: + self.logger.error(f"Error in observation sender: {e}") + + def control_loop(self, get_observation_fn): + """Combined function for executing actions and streaming observations""" + # Wait at barrier for synchronized start + self.start_barrier.wait() + self.logger.info("Control loop thread starting") + + control_loops = 0 + while self.running: + control_loop_start = time.time() + self.control_loop_action() + + """Control loop: (2) Streaming observations to the remote policy server""" + if self._ready_to_send_observation() or control_loops == 0: + self.control_loop_observation(get_observation_fn) + + # Dynamically adjust sleep time to maintain the desired control frequency + time.sleep(max(0, environment_dt - (time.time() - control_loop_start))) + control_loops += 1 + + +def async_client(task_instruction: str, verbose: int = 0): + client = RobotClient() + + if client.start(): + # Function to get observations from the robot + def get_observation(): + observation_content = client.robot.capture_observation() + + observation_content["task"] = [task_instruction] + + observation = TimedObservation( + timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0) + ) + + return observation + + client.logger.info("Starting all threads...") + + # Create and start action receiver thread + action_receiver_thread = threading.Thread(target=client.receive_actions) + action_receiver_thread.daemon = True + + control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,)) + control_loop_thread.daemon = True + + # Start all threads + action_receiver_thread.start() + control_loop_thread.start() + + try: + while client.running: + time.sleep(idle_wait) + + except KeyboardInterrupt: + pass + + finally: + client.stop() + client.logger.info("Client stopped") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Robot client for executing tasks via policy server") + parser.add_argument( + "--task", + type=str, + required=True, + help="Task instruction for the robot to execute (e.g., 'fold my tshirt')", + ) + parser.add_argument("--verbose", type=int, default=0, help="Verbosity level (default: 0)") + parser.add_argument( + "--server-address", + type=str, + default=None, + help="Server address and port (default: SERVER_ADDRESS env var if set, else localhost:8080)", + ) + parser.add_argument("--policy-type", type=str, default="smolvla", help="Policy type (default: smolvla)") + parser.add_argument(
"--pretrained-name-or-path", + type=str, + default="lerobot/smolvla_base", + help="Pretrained model name or path (default: lerobot/smolvla_base)", + ) + parser.add_argument( + "--policy-device", type=str, default="cuda", help="Device for policy inference (default: cuda)" + ) + parser.add_argument( + "--chunk-size-threshold", + type=float, + default=0.5, + help="Chunk size threshold (`g` in the paper, default: 0.5)", + ) + parser.add_argument( + "--robot", + type=str, + default="so100", + help="Robot name, as per the `make_robot` function (default: so100)", + ) + + args = parser.parse_args() + + # Create client with parsed arguments + client = RobotClient( + server_address=args.server_address, + policy_type=args.policy_type, + pretrained_name_or_path=args.pretrained_name_or_path, + policy_device=args.policy_device, + chunk_size_threshold=args.chunk_size_threshold, + robot=args.robot, + ) + + if client.start(): + # Function to get observations from the robot + def get_observation(): + observation_content = None + observation_content = client.robot.capture_observation() + + observation_content["task"] = [args.task] + + observation = TimedObservation( + timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0) + ) + + return observation + + client.logger.info("Starting all threads...") + + # Create and start action receiver thread + action_receiver_thread = threading.Thread(target=client.receive_actions) + action_receiver_thread.daemon = True + + control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,)) + control_loop_thread.daemon = True + + # Start all threads + action_receiver_thread.start() + control_loop_thread.start() + + try: + while client.running: + time.sleep(idle_wait) + + except KeyboardInterrupt: + pass + + finally: + client.stop() + client.logger.info("Client stopped")