From 6eeab64f8ac9cec94b18833f8099943023887609 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 2 Jun 2025 14:46:56 +0700 Subject: [PATCH] [PORT HIL-SERL] Refactor folders structure | Rebased version (#1178) --- lerobot/common/policies/sac/modeling_sac.py | 2 +- .../transport/services.proto} | 8 +- lerobot/common/transport/services_pb2.py | 46 ++ .../transport/services_pb2_grpc.py} | 90 +-- .../transport/utils.py} | 16 +- .../server => common/utils}/buffer.py | 2 +- .../utils/end_effector_control.py} | 0 lerobot/common/utils/process.py | 53 ++ lerobot/common/utils/queue.py | 35 ++ .../utils.py => common/utils/transition.py} | 58 +- .../scripts/{server => }/find_joint_limits.py | 0 .../{server/actor_server.py => rl/actor.py} | 77 +-- .../{server => rl}/crop_dataset_roi.py | 0 .../scripts/{server => rl}/gym_manipulator.py | 0 .../learner_server.py => rl/learner.py} | 41 +- .../scripts/{server => rl}/learner_service.py | 24 +- lerobot/scripts/server/hilserl_pb2.py | 46 -- lerobot/scripts/server/kinematics.py | 546 ------------------ tests/{server => utils}/test_replay_buffer.py | 2 +- 19 files changed, 270 insertions(+), 776 deletions(-) rename lerobot/{scripts/server/hilserl.proto => common/transport/services.proto} (82%) create mode 100644 lerobot/common/transport/services_pb2.py rename lerobot/{scripts/server/hilserl_pb2_grpc.py => common/transport/services_pb2_grpc.py} (64%) rename lerobot/{scripts/server/network_utils.py => common/transport/utils.py} (88%) rename lerobot/{scripts/server => common/utils}/buffer.py (99%) rename lerobot/{scripts/server/end_effector_control_utils.py => common/utils/end_effector_control.py} (100%) create mode 100644 lerobot/common/utils/process.py create mode 100644 lerobot/common/utils/queue.py rename lerobot/{scripts/server/utils.py => common/utils/transition.py} (68%) rename lerobot/scripts/{server => }/find_joint_limits.py (100%) rename lerobot/scripts/{server/actor_server.py => rl/actor.py} (93%) rename lerobot/scripts/{server => rl}/crop_dataset_roi.py (100%) rename lerobot/scripts/{server => rl}/gym_manipulator.py (100%) rename lerobot/scripts/{server/learner_server.py => rl/learner.py} (98%) rename lerobot/scripts/{server => rl}/learner_service.py (78%) delete mode 100644 lerobot/scripts/server/hilserl_pb2.py delete mode 100644 lerobot/scripts/server/kinematics.py rename tests/{server => utils}/test_replay_buffer.py (99%) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index a15974a9..d23d7f82 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -606,7 +606,7 @@ class SACObservationEncoder(nn.Module): Usage patterns: - Called in select_action() with normalize=True - - Called in learner_server.py's get_observation_features() to pre-compute features for all policy components + - Called in learner.py's get_observation_features() to pre-compute features for all policy components - Called internally by forward() with normalize=False Args: diff --git a/lerobot/scripts/server/hilserl.proto b/lerobot/common/transport/services.proto similarity index 82% rename from lerobot/scripts/server/hilserl.proto rename to lerobot/common/transport/services.proto index 2474bfe9..7e655997 100644 --- a/lerobot/scripts/server/hilserl.proto +++ b/lerobot/common/transport/services.proto @@ -13,9 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command: +// +// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. lerobot/common/transport/services.proto +// +// The command should be launched from the root of the project. + syntax = "proto3"; -package hil_serl; +package transport; // LearnerService: the Actor calls this to push transitions. // The Learner implements this service. diff --git a/lerobot/common/transport/services_pb2.py b/lerobot/common/transport/services_pb2.py new file mode 100644 index 00000000..ba2120fd --- /dev/null +++ b/lerobot/common/transport/services_pb2.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: lerobot/common/transport/services.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'lerobot/common/transport/services.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'lerobot/common/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xcc\x02\n\x0eLearnerService\x12I\n\x16SendInteractionMessage\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.common.transport.services_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TRANSFERSTATE']._serialized_start=305 + _globals['_TRANSFERSTATE']._serialized_end=401 + _globals['_TRANSITION']._serialized_start=54 + _globals['_TRANSITION']._serialized_end=130 + _globals['_PARAMETERS']._serialized_start=132 + _globals['_PARAMETERS']._serialized_end=208 + _globals['_INTERACTIONMESSAGE']._serialized_start=210 + _globals['_INTERACTIONMESSAGE']._serialized_end=294 + _globals['_EMPTY']._serialized_start=296 + _globals['_EMPTY']._serialized_end=303 + _globals['_LEARNERSERVICE']._serialized_start=404 + _globals['_LEARNERSERVICE']._serialized_end=736 +# @@protoc_insertion_point(module_scope) diff --git a/lerobot/scripts/server/hilserl_pb2_grpc.py b/lerobot/common/transport/services_pb2_grpc.py similarity index 64% rename from lerobot/scripts/server/hilserl_pb2_grpc.py rename to lerobot/common/transport/services_pb2_grpc.py index 1fa96e81..2e36f7b6 100644 --- a/lerobot/scripts/server/hilserl_pb2_grpc.py +++ b/lerobot/common/transport/services_pb2_grpc.py @@ -3,9 +3,9 @@ import grpc import warnings -import hilserl_pb2 as hilserl__pb2 +from lerobot.common.transport import services_pb2 as lerobot_dot_common_dot_transport_dot_services__pb2 -GRPC_GENERATED_VERSION = '1.70.0' +GRPC_GENERATED_VERSION = '1.71.0' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -18,7 +18,7 @@ except ImportError: if _version_not_supported: raise RuntimeError( f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in hilserl_pb2_grpc.py depends on' + + f' but the generated code in lerobot/common/transport/services_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' @@ -37,29 +37,29 @@ class LearnerServiceStub(object): channel: A grpc.Channel. """ self.SendInteractionMessage = channel.unary_unary( - '/hil_serl.LearnerService/SendInteractionMessage', - request_serializer=hilserl__pb2.InteractionMessage.SerializeToString, - response_deserializer=hilserl__pb2.Empty.FromString, + '/transport.LearnerService/SendInteractionMessage', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) self.StreamParameters = channel.unary_stream( - '/hil_serl.LearnerService/StreamParameters', - request_serializer=hilserl__pb2.Empty.SerializeToString, - response_deserializer=hilserl__pb2.Parameters.FromString, + '/transport.LearnerService/StreamParameters', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString, _registered_method=True) self.SendTransitions = channel.stream_unary( - '/hil_serl.LearnerService/SendTransitions', - request_serializer=hilserl__pb2.Transition.SerializeToString, - response_deserializer=hilserl__pb2.Empty.FromString, + '/transport.LearnerService/SendTransitions', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) self.SendInteractions = channel.stream_unary( - '/hil_serl.LearnerService/SendInteractions', - request_serializer=hilserl__pb2.InteractionMessage.SerializeToString, - response_deserializer=hilserl__pb2.Empty.FromString, + '/transport.LearnerService/SendInteractions', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) self.Ready = channel.unary_unary( - '/hil_serl.LearnerService/Ready', - request_serializer=hilserl__pb2.Empty.SerializeToString, - response_deserializer=hilserl__pb2.Empty.FromString, + '/transport.LearnerService/Ready', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) @@ -104,34 +104,34 @@ def add_LearnerServiceServicer_to_server(servicer, server): rpc_method_handlers = { 'SendInteractionMessage': grpc.unary_unary_rpc_method_handler( servicer.SendInteractionMessage, - request_deserializer=hilserl__pb2.InteractionMessage.FromString, - response_serializer=hilserl__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, ), 'StreamParameters': grpc.unary_stream_rpc_method_handler( servicer.StreamParameters, - request_deserializer=hilserl__pb2.Empty.FromString, - response_serializer=hilserl__pb2.Parameters.SerializeToString, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.SerializeToString, ), 'SendTransitions': grpc.stream_unary_rpc_method_handler( servicer.SendTransitions, - request_deserializer=hilserl__pb2.Transition.FromString, - response_serializer=hilserl__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, ), 'SendInteractions': grpc.stream_unary_rpc_method_handler( servicer.SendInteractions, - request_deserializer=hilserl__pb2.InteractionMessage.FromString, - response_serializer=hilserl__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, ), 'Ready': grpc.unary_unary_rpc_method_handler( servicer.Ready, - request_deserializer=hilserl__pb2.Empty.FromString, - response_serializer=hilserl__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( - 'hil_serl.LearnerService', rpc_method_handlers) + 'transport.LearnerService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers) + server.add_registered_method_handlers('transport.LearnerService', rpc_method_handlers) # This class is part of an EXPERIMENTAL API. @@ -154,9 +154,9 @@ class LearnerService(object): return grpc.experimental.unary_unary( request, target, - '/hil_serl.LearnerService/SendInteractionMessage', - hilserl__pb2.InteractionMessage.SerializeToString, - hilserl__pb2.Empty.FromString, + '/transport.LearnerService/SendInteractionMessage', + lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, @@ -181,9 +181,9 @@ class LearnerService(object): return grpc.experimental.unary_stream( request, target, - '/hil_serl.LearnerService/StreamParameters', - hilserl__pb2.Empty.SerializeToString, - hilserl__pb2.Parameters.FromString, + '/transport.LearnerService/StreamParameters', + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString, options, channel_credentials, insecure, @@ -208,9 +208,9 @@ class LearnerService(object): return grpc.experimental.stream_unary( request_iterator, target, - '/hil_serl.LearnerService/SendTransitions', - hilserl__pb2.Transition.SerializeToString, - hilserl__pb2.Empty.FromString, + '/transport.LearnerService/SendTransitions', + lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, @@ -235,9 +235,9 @@ class LearnerService(object): return grpc.experimental.stream_unary( request_iterator, target, - '/hil_serl.LearnerService/SendInteractions', - hilserl__pb2.InteractionMessage.SerializeToString, - hilserl__pb2.Empty.FromString, + '/transport.LearnerService/SendInteractions', + lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, @@ -262,9 +262,9 @@ class LearnerService(object): return grpc.experimental.unary_unary( request, target, - '/hil_serl.LearnerService/Ready', - hilserl__pb2.Empty.SerializeToString, - hilserl__pb2.Empty.FromString, + '/transport.LearnerService/Ready', + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, diff --git a/lerobot/scripts/server/network_utils.py b/lerobot/common/transport/utils.py similarity index 88% rename from lerobot/scripts/server/network_utils.py rename to lerobot/common/transport/utils.py index 1b1d8044..6ef1c801 100644 --- a/lerobot/scripts/server/network_utils.py +++ b/lerobot/common/transport/utils.py @@ -23,8 +23,8 @@ from typing import Any import torch -from lerobot.scripts.server import hilserl_pb2 -from lerobot.scripts.server.utils import Transition +from lerobot.common.transport import services_pb2 +from lerobot.common.utils.transition import Transition CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB @@ -47,12 +47,12 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "" logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") while sent_bytes < size_in_bytes: - transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE + transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE if sent_bytes + CHUNK_SIZE >= size_in_bytes: - transfer_state = hilserl_pb2.TransferState.TRANSFER_END + transfer_state = services_pb2.TransferState.TRANSFER_END elif sent_bytes == 0: - transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN + transfer_state = services_pb2.TransferState.TRANSFER_BEGIN size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) chunk = buffer.read(size_to_read) @@ -75,18 +75,18 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p logging.info(f"{log_prefix} Shutting down receiver") return - if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN: + if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN: bytes_buffer.seek(0) bytes_buffer.truncate(0) bytes_buffer.write(item.data) logging.debug(f"{log_prefix} Received data at step 0") step = 0 continue - elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE: + elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE: bytes_buffer.write(item.data) step += 1 logging.debug(f"{log_prefix} Received data at step {step}") - elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END: + elif item.transfer_state == services_pb2.TransferState.TRANSFER_END: bytes_buffer.write(item.data) logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") diff --git a/lerobot/scripts/server/buffer.py b/lerobot/common/utils/buffer.py similarity index 99% rename from lerobot/scripts/server/buffer.py rename to lerobot/common/utils/buffer.py index 45d0d089..fa6779ba 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/common/utils/buffer.py @@ -23,7 +23,7 @@ import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.scripts.server.utils import Transition +from lerobot.common.utils.transition import Transition class BatchTransition(TypedDict): diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/common/utils/end_effector_control.py similarity index 100% rename from lerobot/scripts/server/end_effector_control_utils.py rename to lerobot/common/utils/end_effector_control.py diff --git a/lerobot/common/utils/process.py b/lerobot/common/utils/process.py new file mode 100644 index 00000000..c5c6241d --- /dev/null +++ b/lerobot/common/utils/process.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import signal +import sys + +shutdown_event_counter = 0 + + +def setup_process_handlers(use_threads: bool) -> any: + if use_threads: + from threading import Event + else: + from multiprocessing import Event + + shutdown_event = Event() + + # Define signal handler + def signal_handler(signum, frame): + logging.info("Shutdown signal received. Cleaning up...") + shutdown_event.set() + global shutdown_event_counter + shutdown_event_counter += 1 + + if shutdown_event_counter > 1: + logging.info("Force shutdown") + sys.exit(1) + + signal.signal(signal.SIGINT, signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill) + signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup + signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\ + + def signal_handler(signum, frame): + logging.info("Shutdown signal received. Cleaning up...") + shutdown_event.set() + + return shutdown_event diff --git a/lerobot/common/utils/queue.py b/lerobot/common/utils/queue.py new file mode 100644 index 00000000..f285aa07 --- /dev/null +++ b/lerobot/common/utils/queue.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from queue import Empty, Queue + + +def get_last_item_from_queue(queue: Queue): + item = queue.get() + counter = 1 + + # Drain queue and keep only the most recent parameters + try: + while True: + item = queue.get_nowait() + counter += 1 + except Empty: + pass + + logging.debug(f"Drained {counter} items from queue") + + return item diff --git a/lerobot/scripts/server/utils.py b/lerobot/common/utils/transition.py similarity index 68% rename from lerobot/scripts/server/utils.py rename to lerobot/common/utils/transition.py index a9486b6c..a455690f 100644 --- a/lerobot/scripts/server/utils.py +++ b/lerobot/common/utils/transition.py @@ -1,7 +1,6 @@ #!/usr/bin/env python -# Copyright 2024 The HuggingFace Inc. team. -# All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,64 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -import signal -import sys -from queue import Empty from typing import TypedDict import torch -from torch.multiprocessing import Queue - -shutdown_event_counter = 0 - - -def setup_process_handlers(use_threads: bool) -> any: - if use_threads: - from threading import Event - else: - from multiprocessing import Event - - shutdown_event = Event() - - # Define signal handler - def signal_handler(signum, frame): - logging.info("Shutdown signal received. Cleaning up...") - shutdown_event.set() - global shutdown_event_counter - shutdown_event_counter += 1 - - if shutdown_event_counter > 1: - logging.info("Force shutdown") - sys.exit(1) - - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill) - signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup - signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\ - - def signal_handler(signum, frame): - logging.info("Shutdown signal received. Cleaning up...") - shutdown_event.set() - - return shutdown_event - - -def get_last_item_from_queue(queue: Queue): - item = queue.get() - counter = 1 - - # Drain queue and keep only the most recent parameters - try: - while True: - item = queue.get_nowait() - counter += 1 - except Empty: - pass - - logging.debug(f"Drained {counter} items from queue") - - return item class Transition(TypedDict): diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/find_joint_limits.py similarity index 100% rename from lerobot/scripts/server/find_joint_limits.py rename to lerobot/scripts/find_joint_limits.py diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/rl/actor.py similarity index 93% rename from lerobot/scripts/server/actor_server.py rename to lerobot/scripts/rl/actor.py index 69c31e80..f9b68ff0 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/rl/actor.py @@ -24,12 +24,12 @@ Examples of usage: - Start an actor server for real robot training with human-in-the-loop intervention: ```bash -python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json +python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json ``` - Run with a specific robot type for a pick and place task: ```bash -python lerobot/scripts/server/actor_server.py \ +python lerobot/scripts/rl/actor.py \ --config_path lerobot/configs/train_config_hilserl_so100.json \ --robot.type=so100 \ --task=pick_and_place @@ -37,7 +37,7 @@ python lerobot/scripts/server/actor_server.py \ - Set a custom workspace bound for the robot's end-effector: ```bash -python lerobot/scripts/server/actor_server.py \ +python lerobot/scripts/rl/actor.py \ --config_path lerobot/configs/train_config_hilserl_so100.json \ --env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \ --env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]" @@ -45,7 +45,7 @@ python lerobot/scripts/server/actor_server.py \ - Run with specific camera crop parameters: ```bash -python lerobot/scripts/server/actor_server.py \ +python lerobot/scripts/rl/actor.py \ --config_path lerobot/configs/train_config_hilserl_so100.json \ --env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}" ``` @@ -85,8 +85,23 @@ from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.robots import so100_follower_end_effector # noqa: F401 from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401 +from lerobot.common.transport import services_pb2, services_pb2_grpc +from lerobot.common.transport.utils import ( + bytes_to_state_dict, + python_object_to_bytes, + receive_bytes_in_chunks, + send_bytes_in_chunks, + transitions_to_bytes, +) +from lerobot.common.utils.process import setup_process_handlers +from lerobot.common.utils.queue import get_last_item_from_queue from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.transition import ( + Transition, + move_state_dict_to_device, + move_transition_to_device, +) from lerobot.common.utils.utils import ( TimerManager, get_safe_torch_device, @@ -94,22 +109,8 @@ from lerobot.common.utils.utils import ( ) from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service -from lerobot.scripts.server.buffer import Transition -from lerobot.scripts.server.gym_manipulator import make_robot_env -from lerobot.scripts.server.network_utils import ( - bytes_to_state_dict, - python_object_to_bytes, - receive_bytes_in_chunks, - send_bytes_in_chunks, - transitions_to_bytes, -) -from lerobot.scripts.server.utils import ( - get_last_item_from_queue, - move_state_dict_to_device, - move_transition_to_device, - setup_process_handlers, -) +from lerobot.scripts.rl import learner_service +from lerobot.scripts.rl.gym_manipulator import make_robot_env ACTOR_SHUTDOWN_TIMEOUT = 30 @@ -386,14 +387,14 @@ def act_with_policy( def establish_learner_connection( - stub: hilserl_pb2_grpc.LearnerServiceStub, + stub: services_pb2_grpc.LearnerServiceStub, shutdown_event: Event, # type: ignore attempts: int = 30, ): """Establish a connection with the learner. Args: - stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection. + stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection. shutdown_event (Event): The event to check if the connection should be established. attempts (int): The number of attempts to establish the connection. Returns: @@ -407,7 +408,7 @@ def establish_learner_connection( # Force a connection attempt and check state try: logging.info("[ACTOR] Send ready message to Learner") - if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty(): + if stub.Ready(services_pb2.Empty()) == services_pb2.Empty(): return True except grpc.RpcError as e: logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}") @@ -419,7 +420,7 @@ def establish_learner_connection( def learner_service_client( host: str = "127.0.0.1", port: int = 50051, -) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]: +) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: import json """ @@ -458,7 +459,7 @@ def learner_service_client( ("grpc.service_config", service_config_json), ], ) - stub = hilserl_pb2_grpc.LearnerServiceStub(channel) + stub = services_pb2_grpc.LearnerServiceStub(channel) logging.info("[ACTOR] Learner service client created") return stub, channel @@ -467,7 +468,7 @@ def receive_policy( cfg: TrainPipelineConfig, parameters_queue: Queue, shutdown_event: Event, # type: ignore - learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, grpc_channel: grpc.Channel | None = None, ): """Receive parameters from the learner. @@ -499,7 +500,7 @@ def receive_policy( ) try: - iterator = learner_client.StreamParameters(hilserl_pb2.Empty()) + iterator = learner_client.StreamParameters(services_pb2.Empty()) receive_bytes_in_chunks( iterator, parameters_queue, @@ -519,9 +520,9 @@ def send_transitions( cfg: TrainPipelineConfig, transitions_queue: Queue, shutdown_event: any, # Event, - learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, grpc_channel: grpc.Channel | None = None, -) -> hilserl_pb2.Empty: +) -> services_pb2.Empty: """ Sends transitions to the learner. @@ -530,7 +531,7 @@ def send_transitions( - Transition Data: - A batch of transitions (observation, action, reward, next observation) is collected. - Transitions are moved to the CPU and serialized using PyTorch. - - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner. + - The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner. """ if not use_threads(cfg): @@ -569,9 +570,9 @@ def send_interactions( cfg: TrainPipelineConfig, interactions_queue: Queue, shutdown_event: Event, # type: ignore - learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, grpc_channel: grpc.Channel | None = None, -) -> hilserl_pb2.Empty: +) -> services_pb2.Empty: """ Sends interactions to the learner. @@ -614,7 +615,7 @@ def send_interactions( logging.info("[ACTOR] Interactions process stopped") -def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore +def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> services_pb2.Empty: # type: ignore while not shutdown_event.is_set(): try: message = transitions_queue.get(block=True, timeout=5) @@ -623,16 +624,16 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilse continue yield from send_bytes_in_chunks( - message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions" + message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions" ) - return hilserl_pb2.Empty() + return services_pb2.Empty() def interactions_stream( shutdown_event: Event, # type: ignore interactions_queue: Queue, -) -> hilserl_pb2.Empty: +) -> services_pb2.Empty: while not shutdown_event.is_set(): try: message = interactions_queue.get(block=True, timeout=5) @@ -642,11 +643,11 @@ def interactions_stream( yield from send_bytes_in_chunks( message, - hilserl_pb2.InteractionMessage, + services_pb2.InteractionMessage, log_prefix="[ACTOR] Send interactions", ) - return hilserl_pb2.Empty() + return services_pb2.Empty() ################################################# diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/rl/crop_dataset_roi.py similarity index 100% rename from lerobot/scripts/server/crop_dataset_roi.py rename to lerobot/scripts/rl/crop_dataset_roi.py diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/rl/gym_manipulator.py similarity index 100% rename from lerobot/scripts/server/gym_manipulator.py rename to lerobot/scripts/rl/gym_manipulator.py diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/rl/learner.py similarity index 98% rename from lerobot/scripts/server/learner_server.py rename to lerobot/scripts/rl/learner.py index fea6ea40..ead64577 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/rl/learner.py @@ -25,12 +25,12 @@ Examples of usage: - Start a learner server for training: ```bash -python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json +python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json ``` - Run with specific SAC hyperparameters: ```bash -python lerobot/scripts/server/learner_server.py \ +python lerobot/scripts/rl/learner.py \ --config_path lerobot/configs/train_config_hilserl_so100.json \ --learner.sac.alpha=0.1 \ --learner.sac.gamma=0.99 @@ -38,7 +38,7 @@ python lerobot/scripts/server/learner_server.py \ - Run with a specific dataset and wandb logging: ```bash -python lerobot/scripts/server/learner_server.py \ +python lerobot/scripts/rl/learner.py \ --config_path lerobot/configs/train_config_hilserl_so100.json \ --dataset.repo_id=username/pick_lift_cube \ --wandb.enable=true \ @@ -47,14 +47,14 @@ python lerobot/scripts/server/learner_server.py \ - Run with a pretrained policy for fine-tuning: ```bash -python lerobot/scripts/server/learner_server.py \ +python lerobot/scripts/rl/learner.py \ --config_path lerobot/configs/train_config_hilserl_so100.json \ --pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model ``` - Run with a reward classifier model: ```bash -python lerobot/scripts/server/learner_server.py \ +python lerobot/scripts/rl/learner.py \ --config_path lerobot/configs/train_config_hilserl_so100.json \ --reward_classifier_pretrained_path=outputs/reward_model/best_model ``` @@ -84,7 +84,6 @@ from pathlib import Path from pprint import pformat import grpc -import hilserl_pb2_grpc # type: ignore import torch from termcolor import colored from torch import nn @@ -103,6 +102,14 @@ from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.robots import so100_follower_end_effector # noqa: F401 from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401 +from lerobot.common.transport import services_pb2_grpc +from lerobot.common.transport.utils import ( + bytes_to_python_object, + bytes_to_transitions, + state_to_bytes, +) +from lerobot.common.utils.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.common.utils.process import setup_process_handlers from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.train_utils import ( get_step_checkpoint_dir, @@ -112,6 +119,7 @@ from lerobot.common.utils.train_utils import ( from lerobot.common.utils.train_utils import ( load_training_state as utils_load_training_state, ) +from lerobot.common.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, @@ -120,18 +128,7 @@ from lerobot.common.utils.utils import ( from lerobot.common.utils.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.scripts.server import learner_service -from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions -from lerobot.scripts.server.network_utils import ( - bytes_to_python_object, - bytes_to_transitions, - state_to_bytes, -) -from lerobot.scripts.server.utils import ( - move_state_dict_to_device, - move_transition_to_device, - setup_process_handlers, -) +from lerobot.scripts.rl import learner_service LOG_PREFIX = "[LEARNER]" @@ -244,7 +241,7 @@ def start_learner_threads( concurrency_entity = Process communication_process = concurrency_entity( - target=start_learner_server, + target=start_learner, args=( parameters_queue, transition_queue, @@ -643,7 +640,7 @@ def add_actor_information_and_train( ) -def start_learner_server( +def start_learner( parameters_queue: Queue, transition_queue: Queue, interaction_message_queue: Queue, @@ -666,7 +663,7 @@ def start_learner_server( # Create a process-specific log file log_dir = os.path.join(cfg.output_dir, "logs") os.makedirs(log_dir, exist_ok=True) - log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log") + log_file = os.path.join(log_dir, f"learner_process_{os.getpid()}.log") # Initialize logging with explicit log file init_logging(log_file=log_file, display_pid=True) @@ -693,7 +690,7 @@ def start_learner_server( ], ) - hilserl_pb2_grpc.add_LearnerServiceServicer_to_server( + services_pb2_grpc.add_LearnerServiceServicer_to_server( service, server, ) diff --git a/lerobot/scripts/server/learner_service.py b/lerobot/scripts/rl/learner_service.py similarity index 78% rename from lerobot/scripts/server/learner_service.py rename to lerobot/scripts/rl/learner_service.py index 425611ed..811f3ce1 100644 --- a/lerobot/scripts/server/learner_service.py +++ b/lerobot/scripts/rl/learner_service.py @@ -1,17 +1,21 @@ import logging from multiprocessing import Event, Queue -import hilserl_pb2 # type: ignore -import hilserl_pb2_grpc # type: ignore - -from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks +from lerobot.common.transport import services_pb2, services_pb2_grpc +from lerobot.common.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 -class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): +class LearnerService(services_pb2_grpc.LearnerServiceServicer): + """ + Implementation of the LearnerService gRPC service + This service is used to send parameters to the Actor and receive transitions and interactions from the Actor + check transport.proto for the gRPC service definition + """ + def __init__( self, shutdown_event: Event, # type: ignore @@ -36,7 +40,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): yield from send_bytes_in_chunks( buffer, - hilserl_pb2.Parameters, + services_pb2.Parameters, log_prefix="[LEARNER] Sending parameters", silent=True, ) @@ -46,7 +50,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): self.shutdown_event.wait(self.seconds_between_pushes) logging.info("[LEARNER] Stream parameters finished") - return hilserl_pb2.Empty() + return services_pb2.Empty() def SendTransitions(self, request_iterator, _context): # noqa: N802 # TODO: authorize the request @@ -60,7 +64,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): ) logging.debug("[LEARNER] Finished receiving transitions") - return hilserl_pb2.Empty() + return services_pb2.Empty() def SendInteractions(self, request_iterator, _context): # noqa: N802 # TODO: authorize the request @@ -74,7 +78,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): ) logging.debug("[LEARNER] Finished receiving interactions") - return hilserl_pb2.Empty() + return services_pb2.Empty() def Ready(self, request, context): # noqa: N802 - return hilserl_pb2.Empty() + return services_pb2.Empty() diff --git a/lerobot/scripts/server/hilserl_pb2.py b/lerobot/scripts/server/hilserl_pb2.py deleted file mode 100644 index 4a4cbea7..00000000 --- a/lerobot/scripts/server/hilserl_pb2.py +++ /dev/null @@ -1,46 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# NO CHECKED-IN PROTOBUF GENCODE -# source: hilserl.proto -# Protobuf Python Version: 5.29.0 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 29, - 0, - '', - 'hilserl.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=275 - _globals['_TRANSFERSTATE']._serialized_end=371 - _globals['_TRANSITION']._serialized_start=27 - _globals['_TRANSITION']._serialized_end=102 - _globals['_PARAMETERS']._serialized_start=104 - _globals['_PARAMETERS']._serialized_end=179 - _globals['_INTERACTIONMESSAGE']._serialized_start=181 - _globals['_INTERACTIONMESSAGE']._serialized_end=264 - _globals['_EMPTY']._serialized_start=266 - _globals['_EMPTY']._serialized_end=273 - _globals['_LEARNERSERVICE']._serialized_start=374 - _globals['_LEARNERSERVICE']._serialized_end=696 -# @@protoc_insertion_point(module_scope) diff --git a/lerobot/scripts/server/kinematics.py b/lerobot/scripts/server/kinematics.py deleted file mode 100644 index c42d9b2f..00000000 --- a/lerobot/scripts/server/kinematics.py +++ /dev/null @@ -1,546 +0,0 @@ -# ruff: noqa: N806, N815, N803 - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -from scipy.spatial.transform import Rotation - - -def skew_symmetric(w): - """Creates the skew-symmetric matrix from a 3D vector.""" - return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]]) - - -def rodrigues_rotation(w, theta): - """Computes the rotation matrix using Rodrigues' formula.""" - w_hat = skew_symmetric(w) - return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat - - -def screw_axis_to_transform(S, theta): - """Converts a screw axis to a 4x4 transformation matrix.""" - S_w = S[:3] - S_v = S[3:] - if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation - T = np.eye(4) - T[:3, 3] = S_v * theta - elif np.linalg.norm(S_w) == 1: # Rotation and translation - w_hat = skew_symmetric(S_w) - R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat - t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v - T = np.eye(4) - T[:3, :3] = R - T[:3, 3] = t - else: - raise ValueError("Invalid screw axis parameters") - return T - - -def pose_difference_se3(pose1, pose2): - """ - Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices. - SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, combining rotation (SO(3)) and translation. - Each 4x4 matrix has the following structure, a 3x3 rotation matrix in the top-left and a 3x1 translation vector in the top-right: - - [R11 R12 R13 tx] - [R21 R22 R23 ty] - [R31 R32 R33 tz] - [ 0 0 0 1] - - where Rij is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector. - - pose1 - pose2 - - Args: - pose1: A 4x4 numpy array representing the first pose. - pose2: A 4x4 numpy array representing the second pose. - - Returns: - A tuple (translation_diff, rotation_diff) where: - - translation_diff is a 3x1 numpy array representing the translational difference. - - rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation. - """ - - # Extract rotation matrices from poses - R1 = pose1[:3, :3] - R2 = pose2[:3, :3] - - # Calculate translational difference - translation_diff = pose1[:3, 3] - pose2[:3, 3] - - # Calculate rotational difference using scipy's Rotation library - R_diff = Rotation.from_matrix(R1 @ R2.T) - rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation - - return np.concatenate([translation_diff, rotation_diff]) - - -def se3_error(target_pose, current_pose): - pos_error = target_pose[:3, 3] - current_pose[:3, 3] - R_target = target_pose[:3, :3] - R_current = current_pose[:3, :3] - R_error = R_target @ R_current.T - rot_error = Rotation.from_matrix(R_error).as_rotvec() - return np.concatenate([pos_error, rot_error]) - - -class RobotKinematics: - """Robot kinematics class supporting multiple robot models.""" - - # Robot measurements dictionary - ROBOT_MEASUREMENTS = { - "koch": { - "gripper": [0.239, -0.001, 0.024], - "wrist": [0.209, 0, 0.024], - "forearm": [0.108, 0, 0.02], - "humerus": [0, 0, 0.036], - "shoulder": [0, 0, 0], - "base": [0, 0, 0.02], - }, - "so100": { - "gripper": [0.320, 0, 0.050], - "wrist": [0.278, 0, 0.050], - "forearm": [0.143, 0, 0.044], - "humerus": [0.031, 0, 0.072], - "shoulder": [0, 0, 0], - "base": [0, 0, 0.02], - }, - "moss": { - "gripper": [0.246, 0.013, 0.111], - "wrist": [0.245, 0.002, 0.064], - "forearm": [0.122, 0, 0.064], - "humerus": [0.001, 0.001, 0.063], - "shoulder": [0, 0, 0], - "base": [0, 0, 0.02], - }, - } - - def __init__(self, robot_type="so100"): - """Initialize kinematics for the specified robot type. - - Args: - robot_type: String specifying the robot model ("koch", "so100", or "moss") - """ - if robot_type not in self.ROBOT_MEASUREMENTS: - raise ValueError( - f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}" - ) - - self.robot_type = robot_type - self.measurements = self.ROBOT_MEASUREMENTS[robot_type] - - # Initialize all transformation matrices and screw axes - self._setup_transforms() - - def _create_translation_matrix(self, x=0, y=0, z=0): - """Create a 4x4 translation matrix.""" - return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]) - - def _setup_transforms(self): - """Setup all transformation matrices and screw axes for the robot.""" - # Set up rotation matrices (constant across robot types) - - # Gripper orientation - self.gripper_X0 = np.array( - [ - [1, 0, 0, 0], - [0, 0, 1, 0], - [0, -1, 0, 0], - [0, 0, 0, 1], - ] - ) - - # Wrist orientation - self.wrist_X0 = np.array( - [ - [0, -1, 0, 0], - [1, 0, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1], - ] - ) - - # Base orientation - self.base_X0 = np.array( - [ - [0, 0, 1, 0], - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 0, 1], - ] - ) - - # Gripper - # Screw axis of gripper frame wrt base frame - self.S_BG = np.array( - [ - 1, - 0, - 0, - 0, - self.measurements["gripper"][2], - -self.measurements["gripper"][1], - ] - ) - - # Gripper origin to centroid transform - self.X_GoGc = self._create_translation_matrix(x=0.07) - - # Gripper origin to tip transform - self.X_GoGt = self._create_translation_matrix(x=0.12) - - # 0-position gripper frame pose wrt base - self.X_BoGo = self._create_translation_matrix( - x=self.measurements["gripper"][0], - y=self.measurements["gripper"][1], - z=self.measurements["gripper"][2], - ) - - # Wrist - # Screw axis of wrist frame wrt base frame - self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]) - - # 0-position origin to centroid transform - self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002) - - # 0-position wrist frame pose wrt base - self.X_BR = self._create_translation_matrix( - x=self.measurements["wrist"][0], - y=self.measurements["wrist"][1], - z=self.measurements["wrist"][2], - ) - - # Forearm - # Screw axis of forearm frame wrt base frame - self.S_BF = np.array( - [ - 0, - 1, - 0, - -self.measurements["forearm"][2], - 0, - self.measurements["forearm"][0], - ] - ) - - # Forearm origin + centroid transform - self.X_FoFc = self._create_translation_matrix(x=0.036) # spellchecker:disable-line - - # 0-position forearm frame pose wrt base - self.X_BF = self._create_translation_matrix( - x=self.measurements["forearm"][0], - y=self.measurements["forearm"][1], - z=self.measurements["forearm"][2], - ) - - # Humerus - # Screw axis of humerus frame wrt base frame - self.S_BH = np.array( - [ - 0, - -1, - 0, - self.measurements["humerus"][2], - 0, - -self.measurements["humerus"][0], - ] - ) - - # Humerus origin to centroid transform - self.X_HoHc = self._create_translation_matrix(x=0.0475) - - # 0-position humerus frame pose wrt base - self.X_BH = self._create_translation_matrix( - x=self.measurements["humerus"][0], - y=self.measurements["humerus"][1], - z=self.measurements["humerus"][2], - ) - - # Shoulder - # Screw axis of shoulder frame wrt Base frame - self.S_BS = np.array([0, 0, -1, 0, 0, 0]) - - # Shoulder origin to centroid transform - self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235) - - # 0-position shoulder frame pose wrt base - self.X_BS = self._create_translation_matrix( - x=self.measurements["shoulder"][0], - y=self.measurements["shoulder"][1], - z=self.measurements["shoulder"][2], - ) - - # Base - # Base origin to centroid transform - self.X_BoBc = self._create_translation_matrix(y=0.015) - - # World to base transform - self.X_WoBo = self._create_translation_matrix( - x=self.measurements["base"][0], - y=self.measurements["base"][1], - z=self.measurements["base"][2], - ) - - # Pre-compute gripper post-multiplication matrix - self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0 - - def fk_base(self): - """Forward kinematics for the base frame.""" - return self.X_WoBo @ self.X_BoBc @ self.base_X0 - - def fk_shoulder(self, robot_pos_deg): - """Forward kinematics for the shoulder frame.""" - robot_pos_rad = robot_pos_deg / 180 * np.pi - return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS - - def fk_humerus(self, robot_pos_deg): - """Forward kinematics for the humerus frame.""" - robot_pos_rad = robot_pos_deg / 180 * np.pi - return ( - self.X_WoBo - @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) - @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) - @ self.X_HoHc - @ self.X_BH - ) - - def fk_forearm(self, robot_pos_deg): - """Forward kinematics for the forearm frame.""" - robot_pos_rad = robot_pos_deg / 180 * np.pi - return ( - self.X_WoBo - @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) - @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) - @ screw_axis_to_transform(self.S_BF, robot_pos_rad[2]) - @ self.X_FoFc # spellchecker:disable-line - @ self.X_BF - ) - - def fk_wrist(self, robot_pos_deg): - """Forward kinematics for the wrist frame.""" - robot_pos_rad = robot_pos_deg / 180 * np.pi - return ( - self.X_WoBo - @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) - @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) - @ screw_axis_to_transform(self.S_BF, robot_pos_rad[2]) - @ screw_axis_to_transform(self.S_BR, robot_pos_rad[3]) - @ self.X_RoRc - @ self.X_BR - @ self.wrist_X0 - ) - - def fk_gripper(self, robot_pos_deg): - """Forward kinematics for the gripper frame.""" - robot_pos_rad = robot_pos_deg / 180 * np.pi - return ( - self.X_WoBo - @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) - @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) - @ screw_axis_to_transform(self.S_BF, robot_pos_rad[2]) - @ screw_axis_to_transform(self.S_BR, robot_pos_rad[3]) - @ screw_axis_to_transform(self.S_BG, robot_pos_rad[4]) - @ self._fk_gripper_post - ) - - def fk_gripper_tip(self, robot_pos_deg): - """Forward kinematics for the gripper tip frame.""" - robot_pos_rad = robot_pos_deg / 180 * np.pi - return ( - self.X_WoBo - @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) - @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) - @ screw_axis_to_transform(self.S_BF, robot_pos_rad[2]) - @ screw_axis_to_transform(self.S_BR, robot_pos_rad[3]) - @ screw_axis_to_transform(self.S_BG, robot_pos_rad[4]) - @ self.X_GoGt - @ self.X_BoGo - @ self.gripper_X0 - ) - - def compute_jacobian(self, robot_pos_deg, fk_func=None): - """Finite differences to compute the Jacobian. - J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change - in the jth joint's velocity. - - Args: - robot_pos_deg: Current joint positions in degrees - fk_func: Forward kinematics function to use (defaults to fk_gripper) - """ - if fk_func is None: - fk_func = self.fk_gripper - - eps = 1e-8 - jac = np.zeros(shape=(6, 5)) - delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) - for el_ix in range(len(robot_pos_deg[:-1])): - delta *= 0 - delta[el_ix] = eps / 2 - Sdot = ( - pose_difference_se3( - fk_func(robot_pos_deg[:-1] + delta), - fk_func(robot_pos_deg[:-1] - delta), - ) - / eps - ) - jac[:, el_ix] = Sdot - return jac - - def compute_positional_jacobian(self, robot_pos_deg, fk_func=None): - """Finite differences to compute the positional Jacobian. - J(i, j) represents how the ith component of the end-effector's position changes wrt a small change - in the jth joint's velocity. - - Args: - robot_pos_deg: Current joint positions in degrees - fk_func: Forward kinematics function to use (defaults to fk_gripper) - """ - if fk_func is None: - fk_func = self.fk_gripper - - eps = 1e-8 - jac = np.zeros(shape=(3, 5)) - delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) - for el_ix in range(len(robot_pos_deg[:-1])): - delta *= 0 - delta[el_ix] = eps / 2 - Sdot = ( - fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3] - ) / eps - jac[:, el_ix] = Sdot - return jac - - def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None): - """Inverse kinematics using gradient descent. - - Args: - current_joint_state: Initial joint positions in degrees - desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix - position_only: If True, only match end-effector position, not orientation - fk_func: Forward kinematics function to use (defaults to fk_gripper) - - Returns: - Joint positions in degrees that achieve the desired end-effector pose - """ - if fk_func is None: - fk_func = self.fk_gripper - - # Do gradient descent. - max_iterations = 5 - learning_rate = 1 - for _ in range(max_iterations): - current_ee_pose = fk_func(current_joint_state) - if not position_only: - error = se3_error(desired_ee_pose, current_ee_pose) - jac = self.compute_jacobian(current_joint_state, fk_func) - else: - error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3] - jac = self.compute_positional_jacobian(current_joint_state, fk_func) - delta_angles = np.linalg.pinv(jac) @ error - current_joint_state[:-1] += learning_rate * delta_angles - - if np.linalg.norm(error) < 5e-3: - return current_joint_state - return current_joint_state - - -if __name__ == "__main__": - import time - - def run_test(robot_type): - """Run test suite for a specific robot type.""" - print(f"\n--- Testing {robot_type.upper()} Robot ---") - - # Initialize kinematics for this robot - robot = RobotKinematics(robot_type) - - # Test 1: Forward kinematics consistency - print("Test 1: Forward kinematics consistency") - test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees - - # Calculate FK for different joints - shoulder_pose = robot.fk_shoulder(test_angles) - humerus_pose = robot.fk_humerus(test_angles) - forearm_pose = robot.fk_forearm(test_angles) - wrist_pose = robot.fk_wrist(test_angles) - gripper_pose = robot.fk_gripper(test_angles) - gripper_tip_pose = robot.fk_gripper_tip(test_angles) - - # Check that poses form a consistent kinematic chain (positions should be progressively further from origin) - distances = [ - np.linalg.norm(shoulder_pose[:3, 3]), - np.linalg.norm(humerus_pose[:3, 3]), - np.linalg.norm(forearm_pose[:3, 3]), - np.linalg.norm(wrist_pose[:3, 3]), - np.linalg.norm(gripper_pose[:3, 3]), - np.linalg.norm(gripper_tip_pose[:3, 3]), - ] - - # Check if distances generally increase along the chain - is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1)) - print(f" Pose distances from origin: {[round(d, 3) for d in distances]}") - print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}") - - # Test 2: Jacobian computation - print("Test 2: Jacobian computation") - jacobian = robot.compute_jacobian(test_angles) - positional_jacobian = robot.compute_positional_jacobian(test_angles) - - # Check shapes - jacobian_shape_ok = jacobian.shape == (6, 5) - pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5) - - print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}") - print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}") - - # Test 3: Inverse kinematics - print("Test 3: Inverse kinematics (position only)") - - # Generate target pose from known joint angles - original_angles = np.array([10, 20, 30, -10, 5, 0]) - target_pose = robot.fk_gripper(original_angles) - - # Start IK from a different position - initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) - - # Measure IK performance - start_time = time.time() - computed_angles = robot.ik(initial_guess.copy(), target_pose) - ik_time = time.time() - start_time - - # Compute resulting pose from IK solution - result_pose = robot.fk_gripper(computed_angles) - - # Calculate position error - pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3]) - passed = pos_error < 0.01 # Accept errors less than 1cm - - print(f" IK computation time: {ik_time:.4f} seconds") - print(f" Position error: {pos_error:.4f}") - print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}") - - return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed - - # Run tests for all robot types - results = {} - for robot_type in ["koch", "so100", "moss"]: - results[robot_type] = run_test(robot_type) - - # Print overall summary - print("\n=== Test Summary ===") - all_passed = all(results.values()) - for robot_type, passed in results.items(): - print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}") - print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}") diff --git a/tests/server/test_replay_buffer.py b/tests/utils/test_replay_buffer.py similarity index 99% rename from tests/server/test_replay_buffer.py rename to tests/utils/test_replay_buffer.py index 5d1cd62f..499f67b7 100644 --- a/tests/server/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -5,7 +5,7 @@ import pytest import torch from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.scripts.server.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized +from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized from tests.fixtures.constants import DUMMY_REPO_ID