From ce2b9724bfe1b5a4c45e61b1890eef3f5ab0909c Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 9 Jul 2025 16:22:40 +0200 Subject: [PATCH] fix(hil-serl): discrete critic send through network (#1468) Co-authored-by: Khalil Meftah Co-authored-by: jpizarrom --- pyproject.toml | 2 +- src/lerobot/scripts/rl/actor.py | 28 ++++++++++-- src/lerobot/scripts/rl/learner.py | 14 +++++- src/lerobot/transport/services.proto | 4 +- src/lerobot/transport/services_pb2.py | 32 ++++++------- src/lerobot/transport/services_pb2_grpc.py | 52 +++++++++++----------- 6 files changed, 81 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 408e3b77..e13a9af0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ dependencies = [ [project.optional-dependencies] aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"] docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] -dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"] +dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"] dora = [ "gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'", ] diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index 0e96d335..cd5e286c 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -317,7 +317,7 @@ def act_with_policy( if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") - update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device) + update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) if len(list_transition_to_send_to_learner) > 0: push_transitions_to_transport_queue( @@ -642,9 +642,29 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) if bytes_state_dict is not None: logging.info("[ACTOR] Load new parameters from Learner.") - state_dict = bytes_to_state_dict(bytes_state_dict) - state_dict = move_state_dict_to_device(state_dict, device=device) - policy.load_state_dict(state_dict) + state_dicts = bytes_to_state_dict(bytes_state_dict) + + # TODO: check encoder parameter synchronization possible issues: + # 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict + # instead of the updated encoder params from critic (which is optimized separately) + # 2. When freeze_vision_encoder=True, we waste bandwidth sending/loading frozen params + # 3. Need to handle encoder params correctly for both actor and discrete_critic + # Potential fixes: + # - Send critic's encoder state when shared_encoder=True + # - Skip encoder params entirely when freeze_vision_encoder=True + # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) + + # Load actor state dict + actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device) + policy.actor.load_state_dict(actor_state_dict) + + # Load discrete critic if present + if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts: + discrete_critic_state_dict = move_state_dict_to_device( + state_dicts["discrete_critic"], device=device + ) + policy.discrete_critic.load_state_dict(discrete_critic_state_dict) + logging.info("[ACTOR] Loaded discrete critic parameters from Learner.") ################################################# diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/scripts/rl/learner.py index d8830d83..edd2363b 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/scripts/rl/learner.py @@ -1109,8 +1109,18 @@ def check_nan_in_transition( def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): logging.debug("[LEARNER] Pushing actor policy to the queue") - state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu") - state_bytes = state_to_bytes(state_dict) + + # Create a dictionary to hold all the state dicts + state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")} + + # Add discrete critic if it exists + if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None: + state_dicts["discrete_critic"] = move_state_dict_to_device( + policy.discrete_critic.state_dict(), device="cpu" + ) + logging.debug("[LEARNER] Including discrete critic in state dict push") + + state_bytes = state_to_bytes(state_dicts) parameters_queue.put(state_bytes) diff --git a/src/lerobot/transport/services.proto b/src/lerobot/transport/services.proto index 89bfc107..70f39741 100644 --- a/src/lerobot/transport/services.proto +++ b/src/lerobot/transport/services.proto @@ -11,11 +11,11 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. +// limitations under the License.python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto // To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command: // -// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. src/lerobot/transport/services.proto +// python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto // // The command should be launched from the root of the project. diff --git a/src/lerobot/transport/services_pb2.py b/src/lerobot/transport/services_pb2.py index 8a213768..9e66ae1e 100644 --- a/src/lerobot/transport/services_pb2.py +++ b/src/lerobot/transport/services_pb2.py @@ -1,6 +1,6 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE -# source: src/lerobot/transport/services.proto +# source: lerobot/transport/services.proto # Protobuf Python Version: 5.29.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor @@ -14,7 +14,7 @@ _runtime_version.ValidateProtobufRuntimeVersion( 29, 0, '', - 'src/lerobot/transport/services.proto' + 'lerobot/transport/services.proto' ) # @@protoc_insertion_point(imports) @@ -23,23 +23,23 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.lerobot.transport.services_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=302 - _globals['_TRANSFERSTATE']._serialized_end=398 - _globals['_TRANSITION']._serialized_start=51 - _globals['_TRANSITION']._serialized_end=127 - _globals['_PARAMETERS']._serialized_start=129 - _globals['_PARAMETERS']._serialized_end=205 - _globals['_INTERACTIONMESSAGE']._serialized_start=207 - _globals['_INTERACTIONMESSAGE']._serialized_end=291 - _globals['_EMPTY']._serialized_start=293 - _globals['_EMPTY']._serialized_end=300 - _globals['_LEARNERSERVICE']._serialized_start=401 - _globals['_LEARNERSERVICE']._serialized_end=658 + _globals['_TRANSFERSTATE']._serialized_start=298 + _globals['_TRANSFERSTATE']._serialized_end=394 + _globals['_TRANSITION']._serialized_start=47 + _globals['_TRANSITION']._serialized_end=123 + _globals['_PARAMETERS']._serialized_start=125 + _globals['_PARAMETERS']._serialized_end=201 + _globals['_INTERACTIONMESSAGE']._serialized_start=203 + _globals['_INTERACTIONMESSAGE']._serialized_end=287 + _globals['_EMPTY']._serialized_start=289 + _globals['_EMPTY']._serialized_end=296 + _globals['_LEARNERSERVICE']._serialized_start=397 + _globals['_LEARNERSERVICE']._serialized_end=654 # @@protoc_insertion_point(module_scope) diff --git a/src/lerobot/transport/services_pb2_grpc.py b/src/lerobot/transport/services_pb2_grpc.py index a4fe8c57..77801a34 100644 --- a/src/lerobot/transport/services_pb2_grpc.py +++ b/src/lerobot/transport/services_pb2_grpc.py @@ -3,7 +3,7 @@ import grpc import warnings -from src.lerobot.transport import services_pb2 as src_dot_lerobot_dot_transport_dot_services__pb2 +from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2 GRPC_GENERATED_VERSION = '1.71.0' GRPC_VERSION = grpc.__version__ @@ -18,7 +18,7 @@ except ImportError: if _version_not_supported: raise RuntimeError( f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in src/lerobot/transport/services_pb2_grpc.py depends on' + + f' but the generated code in lerobot/transport/services_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' @@ -38,23 +38,23 @@ class LearnerServiceStub: """ self.StreamParameters = channel.unary_stream( '/transport.LearnerService/StreamParameters', - request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, - response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString, + request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Parameters.FromString, _registered_method=True) self.SendTransitions = channel.stream_unary( '/transport.LearnerService/SendTransitions', - request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString, - response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + request_serializer=lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) self.SendInteractions = channel.stream_unary( '/transport.LearnerService/SendInteractions', - request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, - response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + request_serializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) self.Ready = channel.unary_unary( '/transport.LearnerService/Ready', - request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, - response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) @@ -93,23 +93,23 @@ def add_LearnerServiceServicer_to_server(servicer, server): rpc_method_handlers = { 'StreamParameters': grpc.unary_stream_rpc_method_handler( servicer.StreamParameters, - request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, - response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString, ), 'SendTransitions': grpc.stream_unary_rpc_method_handler( servicer.SendTransitions, - request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.FromString, - response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Transition.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, ), 'SendInteractions': grpc.stream_unary_rpc_method_handler( servicer.SendInteractions, - request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString, - response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, ), 'Ready': grpc.unary_unary_rpc_method_handler( servicer.Ready, - request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, - response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -139,8 +139,8 @@ class LearnerService: request, target, '/transport.LearnerService/StreamParameters', - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, - src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString, + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Parameters.FromString, options, channel_credentials, insecure, @@ -166,8 +166,8 @@ class LearnerService: request_iterator, target, '/transport.LearnerService/SendTransitions', - src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString, - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, @@ -193,8 +193,8 @@ class LearnerService: request_iterator, target, '/transport.LearnerService/SendInteractions', - src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, @@ -220,8 +220,8 @@ class LearnerService: request, target, '/transport.LearnerService/Ready', - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure,