fix(hil-serl): discrete critic send through network (#1468)

Co-authored-by: Khalil Meftah <kmeftah.khalil@gmail.com>
Co-authored-by: jpizarrom <jpizarrom@gmail.com>
This commit is contained in:
Adil Zouitine
2025-07-09 16:22:40 +02:00
committed by GitHub
parent cf86b9300d
commit ce2b9724bf
6 changed files with 81 additions and 51 deletions

View File

@@ -79,7 +79,7 @@ dependencies = [
[project.optional-dependencies]
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"]
dora = [
"gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
]

View File

@@ -317,7 +317,7 @@ def act_with_policy(
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -642,9 +642,29 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
state_dict = bytes_to_state_dict(bytes_state_dict)
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict)
state_dicts = bytes_to_state_dict(bytes_state_dict)
# TODO: check for possible encoder parameter synchronization issues:
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
# instead of the updated encoder params from critic (which is optimized separately)
# 2. When freeze_vision_encoder=True, we waste bandwidth sending/loading frozen params
# 3. Need to handle encoder params correctly for both actor and discrete_critic
# Potential fixes:
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
policy.actor.load_state_dict(actor_state_dict)
# Load discrete critic if present
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
#################################################

View File

@@ -1109,8 +1109,18 @@ def check_nan_in_transition(
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
logging.debug("[LEARNER] Pushing actor policy to the queue")
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
state_bytes = state_to_bytes(state_dict)
# Create a dictionary to hold all the state dicts
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
# Add discrete critic if it exists
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
policy.discrete_critic.state_dict(), device="cpu"
)
logging.debug("[LEARNER] Including discrete critic in state dict push")
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)

View File

@@ -11,11 +11,11 @@
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
// To generate the classes for the transport part (services_pb2.py and services_pb2_grpc.py), use the following command:
//
// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. src/lerobot/transport/services.proto
// python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto
//
// The command should be launched from the root of the project.

View File

@@ -1,6 +1,6 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: src/lerobot/transport/services.proto
# source: lerobot/transport/services.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
@@ -14,7 +14,7 @@ _runtime_version.ValidateProtobufRuntimeVersion(
29,
0,
'',
'src/lerobot/transport/services.proto'
'lerobot/transport/services.proto'
)
# @@protoc_insertion_point(imports)
@@ -23,23 +23,23 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.lerobot.transport.services_pb2', _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=302
_globals['_TRANSFERSTATE']._serialized_end=398
_globals['_TRANSITION']._serialized_start=51
_globals['_TRANSITION']._serialized_end=127
_globals['_PARAMETERS']._serialized_start=129
_globals['_PARAMETERS']._serialized_end=205
_globals['_INTERACTIONMESSAGE']._serialized_start=207
_globals['_INTERACTIONMESSAGE']._serialized_end=291
_globals['_EMPTY']._serialized_start=293
_globals['_EMPTY']._serialized_end=300
_globals['_LEARNERSERVICE']._serialized_start=401
_globals['_LEARNERSERVICE']._serialized_end=658
_globals['_TRANSFERSTATE']._serialized_start=298
_globals['_TRANSFERSTATE']._serialized_end=394
_globals['_TRANSITION']._serialized_start=47
_globals['_TRANSITION']._serialized_end=123
_globals['_PARAMETERS']._serialized_start=125
_globals['_PARAMETERS']._serialized_end=201
_globals['_INTERACTIONMESSAGE']._serialized_start=203
_globals['_INTERACTIONMESSAGE']._serialized_end=287
_globals['_EMPTY']._serialized_start=289
_globals['_EMPTY']._serialized_end=296
_globals['_LEARNERSERVICE']._serialized_start=397
_globals['_LEARNERSERVICE']._serialized_end=654
# @@protoc_insertion_point(module_scope)

View File

@@ -3,7 +3,7 @@
import grpc
import warnings
from src.lerobot.transport import services_pb2 as src_dot_lerobot_dot_transport_dot_services__pb2
from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in src/lerobot/transport/services_pb2_grpc.py depends on'
+ f' but the generated code in lerobot/transport/services_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -38,23 +38,23 @@ class LearnerServiceStub:
"""
self.StreamParameters = channel.unary_stream(
'/transport.LearnerService/StreamParameters',
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
_registered_method=True)
self.SendTransitions = channel.stream_unary(
'/transport.LearnerService/SendTransitions',
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
request_serializer=lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractions = channel.stream_unary(
'/transport.LearnerService/SendInteractions',
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
request_serializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/transport.LearnerService/Ready',
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
@@ -93,23 +93,23 @@ def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'StreamParameters': grpc.unary_stream_rpc_method_handler(
servicer.StreamParameters,
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString,
),
'SendTransitions': grpc.stream_unary_rpc_method_handler(
servicer.SendTransitions,
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.FromString,
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Transition.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'SendInteractions': grpc.stream_unary_rpc_method_handler(
servicer.SendInteractions,
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString,
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
@@ -139,8 +139,8 @@ class LearnerService:
request,
target,
'/transport.LearnerService/StreamParameters',
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
options,
channel_credentials,
insecure,
@@ -166,8 +166,8 @@ class LearnerService:
request_iterator,
target,
'/transport.LearnerService/SendTransitions',
src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -193,8 +193,8 @@ class LearnerService:
request_iterator,
target,
'/transport.LearnerService/SendInteractions',
src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -220,8 +220,8 @@ class LearnerService:
request,
target,
'/transport.LearnerService/Ready',
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,