[PORT HIL-SERL] Refactor folders structure | Rebased version (#1178)
This commit is contained in:
committed by
AdilZouitine
parent
8feda920da
commit
6eeab64f8a
@@ -606,7 +606,7 @@ class SACObservationEncoder(nn.Module):
|
|||||||
|
|
||||||
Usage patterns:
|
Usage patterns:
|
||||||
- Called in select_action() with normalize=True
|
- Called in select_action() with normalize=True
|
||||||
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
|
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
|
||||||
- Called internally by forward() with normalize=False
|
- Called internally by forward() with normalize=False
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -13,9 +13,15 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
// To generate the classes for the transport part (services_pb2.py and services_pb2_grpc.py), use the following command:
|
||||||
|
//
|
||||||
|
// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. lerobot/common/transport/services.proto
|
||||||
|
//
|
||||||
|
// The command should be launched from the root of the project.
|
||||||
|
|
||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
package hil_serl;
|
package transport;
|
||||||
|
|
||||||
// LearnerService: the Actor calls this to push transitions.
|
// LearnerService: the Actor calls this to push transitions.
|
||||||
// The Learner implements this service.
|
// The Learner implements this service.
|
||||||
46
lerobot/common/transport/services_pb2.py
Normal file
46
lerobot/common/transport/services_pb2.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# NO CHECKED-IN PROTOBUF GENCODE
|
||||||
|
# source: lerobot/common/transport/services.proto
|
||||||
|
# Protobuf Python Version: 5.29.0
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
|
from google.protobuf import runtime_version as _runtime_version
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||||
|
_runtime_version.Domain.PUBLIC,
|
||||||
|
5,
|
||||||
|
29,
|
||||||
|
0,
|
||||||
|
'',
|
||||||
|
'lerobot/common/transport/services.proto'
|
||||||
|
)
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'lerobot/common/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xcc\x02\n\x0eLearnerService\x12I\n\x16SendInteractionMessage\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||||
|
|
||||||
|
_globals = globals()
|
||||||
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.common.transport.services_pb2', _globals)
|
||||||
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._loaded_options = None
|
||||||
|
_globals['_TRANSFERSTATE']._serialized_start=305
|
||||||
|
_globals['_TRANSFERSTATE']._serialized_end=401
|
||||||
|
_globals['_TRANSITION']._serialized_start=54
|
||||||
|
_globals['_TRANSITION']._serialized_end=130
|
||||||
|
_globals['_PARAMETERS']._serialized_start=132
|
||||||
|
_globals['_PARAMETERS']._serialized_end=208
|
||||||
|
_globals['_INTERACTIONMESSAGE']._serialized_start=210
|
||||||
|
_globals['_INTERACTIONMESSAGE']._serialized_end=294
|
||||||
|
_globals['_EMPTY']._serialized_start=296
|
||||||
|
_globals['_EMPTY']._serialized_end=303
|
||||||
|
_globals['_LEARNERSERVICE']._serialized_start=404
|
||||||
|
_globals['_LEARNERSERVICE']._serialized_end=736
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
||||||
@@ -3,9 +3,9 @@
|
|||||||
import grpc
|
import grpc
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import hilserl_pb2 as hilserl__pb2
|
from lerobot.common.transport import services_pb2 as lerobot_dot_common_dot_transport_dot_services__pb2
|
||||||
|
|
||||||
GRPC_GENERATED_VERSION = '1.70.0'
|
GRPC_GENERATED_VERSION = '1.71.0'
|
||||||
GRPC_VERSION = grpc.__version__
|
GRPC_VERSION = grpc.__version__
|
||||||
_version_not_supported = False
|
_version_not_supported = False
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ except ImportError:
|
|||||||
if _version_not_supported:
|
if _version_not_supported:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||||
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
|
+ f' but the generated code in lerobot/common/transport/services_pb2_grpc.py depends on'
|
||||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||||
@@ -37,29 +37,29 @@ class LearnerServiceStub(object):
|
|||||||
channel: A grpc.Channel.
|
channel: A grpc.Channel.
|
||||||
"""
|
"""
|
||||||
self.SendInteractionMessage = channel.unary_unary(
|
self.SendInteractionMessage = channel.unary_unary(
|
||||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
'/transport.LearnerService/SendInteractionMessage',
|
||||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
self.StreamParameters = channel.unary_stream(
|
self.StreamParameters = channel.unary_stream(
|
||||||
'/hil_serl.LearnerService/StreamParameters',
|
'/transport.LearnerService/StreamParameters',
|
||||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
response_deserializer=hilserl__pb2.Parameters.FromString,
|
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
self.SendTransitions = channel.stream_unary(
|
self.SendTransitions = channel.stream_unary(
|
||||||
'/hil_serl.LearnerService/SendTransitions',
|
'/transport.LearnerService/SendTransitions',
|
||||||
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
self.SendInteractions = channel.stream_unary(
|
self.SendInteractions = channel.stream_unary(
|
||||||
'/hil_serl.LearnerService/SendInteractions',
|
'/transport.LearnerService/SendInteractions',
|
||||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
self.Ready = channel.unary_unary(
|
self.Ready = channel.unary_unary(
|
||||||
'/hil_serl.LearnerService/Ready',
|
'/transport.LearnerService/Ready',
|
||||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -104,34 +104,34 @@ def add_LearnerServiceServicer_to_server(servicer, server):
|
|||||||
rpc_method_handlers = {
|
rpc_method_handlers = {
|
||||||
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
|
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
|
||||||
servicer.SendInteractionMessage,
|
servicer.SendInteractionMessage,
|
||||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString,
|
||||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
),
|
),
|
||||||
'StreamParameters': grpc.unary_stream_rpc_method_handler(
|
'StreamParameters': grpc.unary_stream_rpc_method_handler(
|
||||||
servicer.StreamParameters,
|
servicer.StreamParameters,
|
||||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
response_serializer=hilserl__pb2.Parameters.SerializeToString,
|
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.SerializeToString,
|
||||||
),
|
),
|
||||||
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||||
servicer.SendTransitions,
|
servicer.SendTransitions,
|
||||||
request_deserializer=hilserl__pb2.Transition.FromString,
|
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.FromString,
|
||||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
),
|
),
|
||||||
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||||
servicer.SendInteractions,
|
servicer.SendInteractions,
|
||||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString,
|
||||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
),
|
),
|
||||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||||
servicer.Ready,
|
servicer.Ready,
|
||||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
generic_handler = grpc.method_handlers_generic_handler(
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
'hil_serl.LearnerService', rpc_method_handlers)
|
'transport.LearnerService', rpc_method_handlers)
|
||||||
server.add_generic_rpc_handlers((generic_handler,))
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
|
server.add_registered_method_handlers('transport.LearnerService', rpc_method_handlers)
|
||||||
|
|
||||||
|
|
||||||
# This class is part of an EXPERIMENTAL API.
|
# This class is part of an EXPERIMENTAL API.
|
||||||
@@ -154,9 +154,9 @@ class LearnerService(object):
|
|||||||
return grpc.experimental.unary_unary(
|
return grpc.experimental.unary_unary(
|
||||||
request,
|
request,
|
||||||
target,
|
target,
|
||||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
'/transport.LearnerService/SendInteractionMessage',
|
||||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||||
hilserl__pb2.Empty.FromString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
options,
|
options,
|
||||||
channel_credentials,
|
channel_credentials,
|
||||||
insecure,
|
insecure,
|
||||||
@@ -181,9 +181,9 @@ class LearnerService(object):
|
|||||||
return grpc.experimental.unary_stream(
|
return grpc.experimental.unary_stream(
|
||||||
request,
|
request,
|
||||||
target,
|
target,
|
||||||
'/hil_serl.LearnerService/StreamParameters',
|
'/transport.LearnerService/StreamParameters',
|
||||||
hilserl__pb2.Empty.SerializeToString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
hilserl__pb2.Parameters.FromString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||||
options,
|
options,
|
||||||
channel_credentials,
|
channel_credentials,
|
||||||
insecure,
|
insecure,
|
||||||
@@ -208,9 +208,9 @@ class LearnerService(object):
|
|||||||
return grpc.experimental.stream_unary(
|
return grpc.experimental.stream_unary(
|
||||||
request_iterator,
|
request_iterator,
|
||||||
target,
|
target,
|
||||||
'/hil_serl.LearnerService/SendTransitions',
|
'/transport.LearnerService/SendTransitions',
|
||||||
hilserl__pb2.Transition.SerializeToString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||||
hilserl__pb2.Empty.FromString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
options,
|
options,
|
||||||
channel_credentials,
|
channel_credentials,
|
||||||
insecure,
|
insecure,
|
||||||
@@ -235,9 +235,9 @@ class LearnerService(object):
|
|||||||
return grpc.experimental.stream_unary(
|
return grpc.experimental.stream_unary(
|
||||||
request_iterator,
|
request_iterator,
|
||||||
target,
|
target,
|
||||||
'/hil_serl.LearnerService/SendInteractions',
|
'/transport.LearnerService/SendInteractions',
|
||||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||||
hilserl__pb2.Empty.FromString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
options,
|
options,
|
||||||
channel_credentials,
|
channel_credentials,
|
||||||
insecure,
|
insecure,
|
||||||
@@ -262,9 +262,9 @@ class LearnerService(object):
|
|||||||
return grpc.experimental.unary_unary(
|
return grpc.experimental.unary_unary(
|
||||||
request,
|
request,
|
||||||
target,
|
target,
|
||||||
'/hil_serl.LearnerService/Ready',
|
'/transport.LearnerService/Ready',
|
||||||
hilserl__pb2.Empty.SerializeToString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
hilserl__pb2.Empty.FromString,
|
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
options,
|
options,
|
||||||
channel_credentials,
|
channel_credentials,
|
||||||
insecure,
|
insecure,
|
||||||
@@ -23,8 +23,8 @@ from typing import Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.scripts.server import hilserl_pb2
|
from lerobot.common.transport import services_pb2
|
||||||
from lerobot.scripts.server.utils import Transition
|
from lerobot.common.utils.transition import Transition
|
||||||
|
|
||||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||||
|
|
||||||
@@ -47,12 +47,12 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
|
|||||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
||||||
|
|
||||||
while sent_bytes < size_in_bytes:
|
while sent_bytes < size_in_bytes:
|
||||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
|
transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE
|
||||||
|
|
||||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
||||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
|
transfer_state = services_pb2.TransferState.TRANSFER_END
|
||||||
elif sent_bytes == 0:
|
elif sent_bytes == 0:
|
||||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
|
transfer_state = services_pb2.TransferState.TRANSFER_BEGIN
|
||||||
|
|
||||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
||||||
chunk = buffer.read(size_to_read)
|
chunk = buffer.read(size_to_read)
|
||||||
@@ -75,18 +75,18 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
|
|||||||
logging.info(f"{log_prefix} Shutting down receiver")
|
logging.info(f"{log_prefix} Shutting down receiver")
|
||||||
return
|
return
|
||||||
|
|
||||||
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
|
if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN:
|
||||||
bytes_buffer.seek(0)
|
bytes_buffer.seek(0)
|
||||||
bytes_buffer.truncate(0)
|
bytes_buffer.truncate(0)
|
||||||
bytes_buffer.write(item.data)
|
bytes_buffer.write(item.data)
|
||||||
logging.debug(f"{log_prefix} Received data at step 0")
|
logging.debug(f"{log_prefix} Received data at step 0")
|
||||||
step = 0
|
step = 0
|
||||||
continue
|
continue
|
||||||
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
|
elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE:
|
||||||
bytes_buffer.write(item.data)
|
bytes_buffer.write(item.data)
|
||||||
step += 1
|
step += 1
|
||||||
logging.debug(f"{log_prefix} Received data at step {step}")
|
logging.debug(f"{log_prefix} Received data at step {step}")
|
||||||
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
|
elif item.transfer_state == services_pb2.TransferState.TRANSFER_END:
|
||||||
bytes_buffer.write(item.data)
|
bytes_buffer.write(item.data)
|
||||||
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ import torch.nn.functional as F # noqa: N812
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.scripts.server.utils import Transition
|
from lerobot.common.utils.transition import Transition
|
||||||
|
|
||||||
|
|
||||||
class BatchTransition(TypedDict):
|
class BatchTransition(TypedDict):
|
||||||
53
lerobot/common/utils/process.py
Normal file
53
lerobot/common/utils/process.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
shutdown_event_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def setup_process_handlers(use_threads: bool) -> any:
    """Create a shutdown event and install signal handlers that set it.

    Args:
        use_threads: When true, the event is a ``threading.Event``;
            otherwise it is a ``multiprocessing.Event``.

    Returns:
        The newly created event. It is set as soon as one of the handled
        signals is received.

    The handler increments the module-level ``shutdown_event_counter`` on
    every signal; once the counter exceeds 1 the process exits with
    status 1 ("force shutdown").
    """
    # NOTE(review): the ``any`` return annotation is the builtin function, not
    # ``typing.Any``; it is left untouched so the declared signature is unchanged.
    if use_threads:
        from threading import Event
    else:
        from multiprocessing import Event

    shutdown_event = Event()

    def signal_handler(signum, frame):
        global shutdown_event_counter

        logging.info("Shutdown signal received. Cleaning up...")
        shutdown_event.set()
        shutdown_event_counter += 1

        # A repeated signal means the graceful path is not finishing: bail out.
        if shutdown_event_counter > 1:
            logging.info("Force shutdown")
            sys.exit(1)

    # SIGINT: Ctrl+C, SIGTERM: termination request (kill),
    # SIGHUP: terminal closed/hangup, SIGQUIT: Ctrl+\
    # SIGHUP and SIGQUIT are not defined by the ``signal`` module on every
    # platform, so only the signals that actually exist are registered.
    for signal_name in ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT"):
        signum = getattr(signal, signal_name, None)
        if signum is not None:
            signal.signal(signum, signal_handler)

    return shutdown_event
|
||||||
35
lerobot/common/utils/queue.py
Normal file
35
lerobot/common/utils/queue.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from queue import Empty, Queue
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_item_from_queue(queue: Queue):
    """Return the most recent item of ``queue``, discarding older ones.

    Blocks until at least one item is available, then keeps pulling items
    without blocking until the queue reports ``Empty``. The last item
    retrieved is returned and the number of items consumed is logged at
    debug level.
    """
    latest = queue.get()  # blocking: wait for the first item
    counter = 1

    # Keep replacing ``latest`` until nothing is immediately available.
    while True:
        try:
            latest = queue.get_nowait()
        except Empty:
            break
        counter += 1

    logging.debug(f"Drained {counter} items from queue")

    return latest
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
# All rights reserved.
|
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -15,64 +14,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
from queue import Empty
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.multiprocessing import Queue
|
|
||||||
|
|
||||||
shutdown_event_counter = 0
|
|
||||||
|
|
||||||
|
|
||||||
def setup_process_handlers(use_threads: bool) -> any:
    """Create a shutdown event and install signal handlers that set it.

    Args:
        use_threads: When true, the event is a ``threading.Event``;
            otherwise it is a ``multiprocessing.Event``.

    Returns:
        The newly created event. It is set as soon as one of the handled
        signals is received.

    The handler increments the module-level ``shutdown_event_counter`` on
    every signal; once the counter exceeds 1 the process exits with
    status 1 ("force shutdown").
    """
    # NOTE(review): the ``any`` return annotation is the builtin function, not
    # ``typing.Any``; it is left untouched so the declared signature is unchanged.
    if use_threads:
        from threading import Event
    else:
        from multiprocessing import Event

    shutdown_event = Event()

    def signal_handler(signum, frame):
        global shutdown_event_counter

        logging.info("Shutdown signal received. Cleaning up...")
        shutdown_event.set()
        shutdown_event_counter += 1

        # A repeated signal means the graceful path is not finishing: bail out.
        if shutdown_event_counter > 1:
            logging.info("Force shutdown")
            sys.exit(1)

    # SIGINT: Ctrl+C, SIGTERM: termination request (kill),
    # SIGHUP: terminal closed/hangup, SIGQUIT: Ctrl+\
    # SIGHUP and SIGQUIT are not defined by the ``signal`` module on every
    # platform, so only the signals that actually exist are registered.
    for signal_name in ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT"):
        signum = getattr(signal, signal_name, None)
        if signum is not None:
            signal.signal(signum, signal_handler)

    return shutdown_event
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_item_from_queue(queue: Queue):
    """Return the most recent item of ``queue``, discarding older ones.

    Blocks until at least one item is available, then keeps pulling items
    without blocking until the queue reports ``Empty``. The last item
    retrieved is returned and the number of items consumed is logged at
    debug level.
    """
    latest = queue.get()  # blocking: wait for the first item
    counter = 1

    # Keep replacing ``latest`` until nothing is immediately available.
    while True:
        try:
            latest = queue.get_nowait()
        except Empty:
            break
        counter += 1

    logging.debug(f"Drained {counter} items from queue")

    return latest
|
|
||||||
|
|
||||||
|
|
||||||
class Transition(TypedDict):
|
class Transition(TypedDict):
|
||||||
@@ -24,12 +24,12 @@ Examples of usage:
|
|||||||
|
|
||||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||||
```
|
```
|
||||||
|
|
||||||
- Run with a specific robot type for a pick and place task:
|
- Run with a specific robot type for a pick and place task:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/actor_server.py \
|
python lerobot/scripts/rl/actor.py \
|
||||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
--robot.type=so100 \
|
--robot.type=so100 \
|
||||||
--task=pick_and_place
|
--task=pick_and_place
|
||||||
@@ -37,7 +37,7 @@ python lerobot/scripts/server/actor_server.py \
|
|||||||
|
|
||||||
- Set a custom workspace bound for the robot's end-effector:
|
- Set a custom workspace bound for the robot's end-effector:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/actor_server.py \
|
python lerobot/scripts/rl/actor.py \
|
||||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
|
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
|
||||||
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
|
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
|
||||||
@@ -45,7 +45,7 @@ python lerobot/scripts/server/actor_server.py \
|
|||||||
|
|
||||||
- Run with specific camera crop parameters:
|
- Run with specific camera crop parameters:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/actor_server.py \
|
python lerobot/scripts/rl/actor.py \
|
||||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
|
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
|
||||||
```
|
```
|
||||||
@@ -85,8 +85,23 @@ from lerobot.common.policies.factory import make_policy
|
|||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
||||||
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
||||||
|
from lerobot.common.transport import services_pb2, services_pb2_grpc
|
||||||
|
from lerobot.common.transport.utils import (
|
||||||
|
bytes_to_state_dict,
|
||||||
|
python_object_to_bytes,
|
||||||
|
receive_bytes_in_chunks,
|
||||||
|
send_bytes_in_chunks,
|
||||||
|
transitions_to_bytes,
|
||||||
|
)
|
||||||
|
from lerobot.common.utils.process import setup_process_handlers
|
||||||
|
from lerobot.common.utils.queue import get_last_item_from_queue
|
||||||
from lerobot.common.utils.random_utils import set_seed
|
from lerobot.common.utils.random_utils import set_seed
|
||||||
from lerobot.common.utils.robot_utils import busy_wait
|
from lerobot.common.utils.robot_utils import busy_wait
|
||||||
|
from lerobot.common.utils.transition import (
|
||||||
|
Transition,
|
||||||
|
move_state_dict_to_device,
|
||||||
|
move_transition_to_device,
|
||||||
|
)
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
TimerManager,
|
TimerManager,
|
||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
@@ -94,22 +109,8 @@ from lerobot.common.utils.utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
|
from lerobot.scripts.rl import learner_service
|
||||||
from lerobot.scripts.server.buffer import Transition
|
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
|
||||||
from lerobot.scripts.server.network_utils import (
|
|
||||||
bytes_to_state_dict,
|
|
||||||
python_object_to_bytes,
|
|
||||||
receive_bytes_in_chunks,
|
|
||||||
send_bytes_in_chunks,
|
|
||||||
transitions_to_bytes,
|
|
||||||
)
|
|
||||||
from lerobot.scripts.server.utils import (
|
|
||||||
get_last_item_from_queue,
|
|
||||||
move_state_dict_to_device,
|
|
||||||
move_transition_to_device,
|
|
||||||
setup_process_handlers,
|
|
||||||
)
|
|
||||||
|
|
||||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||||
|
|
||||||
@@ -386,14 +387,14 @@ def act_with_policy(
|
|||||||
|
|
||||||
|
|
||||||
def establish_learner_connection(
|
def establish_learner_connection(
|
||||||
stub: hilserl_pb2_grpc.LearnerServiceStub,
|
stub: services_pb2_grpc.LearnerServiceStub,
|
||||||
shutdown_event: Event, # type: ignore
|
shutdown_event: Event, # type: ignore
|
||||||
attempts: int = 30,
|
attempts: int = 30,
|
||||||
):
|
):
|
||||||
"""Establish a connection with the learner.
|
"""Establish a connection with the learner.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
||||||
shutdown_event (Event): The event to check if the connection should be established.
|
shutdown_event (Event): The event to check if the connection should be established.
|
||||||
attempts (int): The number of attempts to establish the connection.
|
attempts (int): The number of attempts to establish the connection.
|
||||||
Returns:
|
Returns:
|
||||||
@@ -407,7 +408,7 @@ def establish_learner_connection(
|
|||||||
# Force a connection attempt and check state
|
# Force a connection attempt and check state
|
||||||
try:
|
try:
|
||||||
logging.info("[ACTOR] Send ready message to Learner")
|
logging.info("[ACTOR] Send ready message to Learner")
|
||||||
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
|
if stub.Ready(services_pb2.Empty()) == services_pb2.Empty():
|
||||||
return True
|
return True
|
||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||||
@@ -419,7 +420,7 @@ def establish_learner_connection(
|
|||||||
def learner_service_client(
|
def learner_service_client(
|
||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 50051,
|
port: int = 50051,
|
||||||
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -458,7 +459,7 @@ def learner_service_client(
|
|||||||
("grpc.service_config", service_config_json),
|
("grpc.service_config", service_config_json),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
|
stub = services_pb2_grpc.LearnerServiceStub(channel)
|
||||||
logging.info("[ACTOR] Learner service client created")
|
logging.info("[ACTOR] Learner service client created")
|
||||||
return stub, channel
|
return stub, channel
|
||||||
|
|
||||||
@@ -467,7 +468,7 @@ def receive_policy(
|
|||||||
cfg: TrainPipelineConfig,
|
cfg: TrainPipelineConfig,
|
||||||
parameters_queue: Queue,
|
parameters_queue: Queue,
|
||||||
shutdown_event: Event, # type: ignore
|
shutdown_event: Event, # type: ignore
|
||||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: grpc.Channel | None = None,
|
||||||
):
|
):
|
||||||
"""Receive parameters from the learner.
|
"""Receive parameters from the learner.
|
||||||
@@ -499,7 +500,7 @@ def receive_policy(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
|
iterator = learner_client.StreamParameters(services_pb2.Empty())
|
||||||
receive_bytes_in_chunks(
|
receive_bytes_in_chunks(
|
||||||
iterator,
|
iterator,
|
||||||
parameters_queue,
|
parameters_queue,
|
||||||
@@ -519,9 +520,9 @@ def send_transitions(
|
|||||||
cfg: TrainPipelineConfig,
|
cfg: TrainPipelineConfig,
|
||||||
transitions_queue: Queue,
|
transitions_queue: Queue,
|
||||||
shutdown_event: any, # Event,
|
shutdown_event: any, # Event,
|
||||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: grpc.Channel | None = None,
|
||||||
) -> hilserl_pb2.Empty:
|
) -> services_pb2.Empty:
|
||||||
"""
|
"""
|
||||||
Sends transitions to the learner.
|
Sends transitions to the learner.
|
||||||
|
|
||||||
@@ -530,7 +531,7 @@ def send_transitions(
|
|||||||
- Transition Data:
|
- Transition Data:
|
||||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
@@ -569,9 +570,9 @@ def send_interactions(
|
|||||||
cfg: TrainPipelineConfig,
|
cfg: TrainPipelineConfig,
|
||||||
interactions_queue: Queue,
|
interactions_queue: Queue,
|
||||||
shutdown_event: Event, # type: ignore
|
shutdown_event: Event, # type: ignore
|
||||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: grpc.Channel | None = None,
|
||||||
) -> hilserl_pb2.Empty:
|
) -> services_pb2.Empty:
|
||||||
"""
|
"""
|
||||||
Sends interactions to the learner.
|
Sends interactions to the learner.
|
||||||
|
|
||||||
@@ -614,7 +615,7 @@ def send_interactions(
|
|||||||
logging.info("[ACTOR] Interactions process stopped")
|
logging.info("[ACTOR] Interactions process stopped")
|
||||||
|
|
||||||
|
|
||||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
|
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> services_pb2.Empty: # type: ignore
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
message = transitions_queue.get(block=True, timeout=5)
|
message = transitions_queue.get(block=True, timeout=5)
|
||||||
@@ -623,16 +624,16 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilse
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
yield from send_bytes_in_chunks(
|
yield from send_bytes_in_chunks(
|
||||||
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||||
)
|
)
|
||||||
|
|
||||||
return hilserl_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
|
|
||||||
def interactions_stream(
|
def interactions_stream(
|
||||||
shutdown_event: Event, # type: ignore
|
shutdown_event: Event, # type: ignore
|
||||||
interactions_queue: Queue,
|
interactions_queue: Queue,
|
||||||
) -> hilserl_pb2.Empty:
|
) -> services_pb2.Empty:
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
message = interactions_queue.get(block=True, timeout=5)
|
message = interactions_queue.get(block=True, timeout=5)
|
||||||
@@ -642,11 +643,11 @@ def interactions_stream(
|
|||||||
|
|
||||||
yield from send_bytes_in_chunks(
|
yield from send_bytes_in_chunks(
|
||||||
message,
|
message,
|
||||||
hilserl_pb2.InteractionMessage,
|
services_pb2.InteractionMessage,
|
||||||
log_prefix="[ACTOR] Send interactions",
|
log_prefix="[ACTOR] Send interactions",
|
||||||
)
|
)
|
||||||
|
|
||||||
return hilserl_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
@@ -25,12 +25,12 @@ Examples of usage:
|
|||||||
|
|
||||||
- Start a learner server for training:
|
- Start a learner server for training:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||||
```
|
```
|
||||||
|
|
||||||
- Run with specific SAC hyperparameters:
|
- Run with specific SAC hyperparameters:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/learner_server.py \
|
python lerobot/scripts/rl/learner.py \
|
||||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
--learner.sac.alpha=0.1 \
|
--learner.sac.alpha=0.1 \
|
||||||
--learner.sac.gamma=0.99
|
--learner.sac.gamma=0.99
|
||||||
@@ -38,7 +38,7 @@ python lerobot/scripts/server/learner_server.py \
|
|||||||
|
|
||||||
- Run with a specific dataset and wandb logging:
|
- Run with a specific dataset and wandb logging:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/learner_server.py \
|
python lerobot/scripts/rl/learner.py \
|
||||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
--dataset.repo_id=username/pick_lift_cube \
|
--dataset.repo_id=username/pick_lift_cube \
|
||||||
--wandb.enable=true \
|
--wandb.enable=true \
|
||||||
@@ -47,14 +47,14 @@ python lerobot/scripts/server/learner_server.py \
|
|||||||
|
|
||||||
- Run with a pretrained policy for fine-tuning:
|
- Run with a pretrained policy for fine-tuning:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/learner_server.py \
|
python lerobot/scripts/rl/learner.py \
|
||||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
--pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model
|
--pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model
|
||||||
```
|
```
|
||||||
|
|
||||||
- Run with a reward classifier model:
|
- Run with a reward classifier model:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/server/learner_server.py \
|
python lerobot/scripts/rl/learner.py \
|
||||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
--reward_classifier_pretrained_path=outputs/reward_model/best_model
|
--reward_classifier_pretrained_path=outputs/reward_model/best_model
|
||||||
```
|
```
|
||||||
@@ -84,7 +84,6 @@ from pathlib import Path
|
|||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
import hilserl_pb2_grpc # type: ignore
|
|
||||||
import torch
|
import torch
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -103,6 +102,14 @@ from lerobot.common.policies.factory import make_policy
|
|||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
||||||
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
||||||
|
from lerobot.common.transport import services_pb2_grpc
|
||||||
|
from lerobot.common.transport.utils import (
|
||||||
|
bytes_to_python_object,
|
||||||
|
bytes_to_transitions,
|
||||||
|
state_to_bytes,
|
||||||
|
)
|
||||||
|
from lerobot.common.utils.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||||
|
from lerobot.common.utils.process import setup_process_handlers
|
||||||
from lerobot.common.utils.random_utils import set_seed
|
from lerobot.common.utils.random_utils import set_seed
|
||||||
from lerobot.common.utils.train_utils import (
|
from lerobot.common.utils.train_utils import (
|
||||||
get_step_checkpoint_dir,
|
get_step_checkpoint_dir,
|
||||||
@@ -112,6 +119,7 @@ from lerobot.common.utils.train_utils import (
|
|||||||
from lerobot.common.utils.train_utils import (
|
from lerobot.common.utils.train_utils import (
|
||||||
load_training_state as utils_load_training_state,
|
load_training_state as utils_load_training_state,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.utils.transition import move_state_dict_to_device, move_transition_to_device
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
format_big_number,
|
format_big_number,
|
||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
@@ -120,18 +128,7 @@ from lerobot.common.utils.utils import (
|
|||||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.scripts.server import learner_service
|
from lerobot.scripts.rl import learner_service
|
||||||
from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions
|
|
||||||
from lerobot.scripts.server.network_utils import (
|
|
||||||
bytes_to_python_object,
|
|
||||||
bytes_to_transitions,
|
|
||||||
state_to_bytes,
|
|
||||||
)
|
|
||||||
from lerobot.scripts.server.utils import (
|
|
||||||
move_state_dict_to_device,
|
|
||||||
move_transition_to_device,
|
|
||||||
setup_process_handlers,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG_PREFIX = "[LEARNER]"
|
LOG_PREFIX = "[LEARNER]"
|
||||||
|
|
||||||
@@ -244,7 +241,7 @@ def start_learner_threads(
|
|||||||
concurrency_entity = Process
|
concurrency_entity = Process
|
||||||
|
|
||||||
communication_process = concurrency_entity(
|
communication_process = concurrency_entity(
|
||||||
target=start_learner_server,
|
target=start_learner,
|
||||||
args=(
|
args=(
|
||||||
parameters_queue,
|
parameters_queue,
|
||||||
transition_queue,
|
transition_queue,
|
||||||
@@ -643,7 +640,7 @@ def add_actor_information_and_train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def start_learner_server(
|
def start_learner(
|
||||||
parameters_queue: Queue,
|
parameters_queue: Queue,
|
||||||
transition_queue: Queue,
|
transition_queue: Queue,
|
||||||
interaction_message_queue: Queue,
|
interaction_message_queue: Queue,
|
||||||
@@ -666,7 +663,7 @@ def start_learner_server(
|
|||||||
# Create a process-specific log file
|
# Create a process-specific log file
|
||||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
|
log_file = os.path.join(log_dir, f"learner_process_{os.getpid()}.log")
|
||||||
|
|
||||||
# Initialize logging with explicit log file
|
# Initialize logging with explicit log file
|
||||||
init_logging(log_file=log_file, display_pid=True)
|
init_logging(log_file=log_file, display_pid=True)
|
||||||
@@ -693,7 +690,7 @@ def start_learner_server(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
|
services_pb2_grpc.add_LearnerServiceServicer_to_server(
|
||||||
service,
|
service,
|
||||||
server,
|
server,
|
||||||
)
|
)
|
||||||
@@ -1,17 +1,21 @@
|
|||||||
import logging
|
import logging
|
||||||
from multiprocessing import Event, Queue
|
from multiprocessing import Event, Queue
|
||||||
|
|
||||||
import hilserl_pb2 # type: ignore
|
from lerobot.common.transport import services_pb2, services_pb2_grpc
|
||||||
import hilserl_pb2_grpc # type: ignore
|
from lerobot.common.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||||
|
|
||||||
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
|
||||||
|
|
||||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||||
SHUTDOWN_TIMEOUT = 10
|
SHUTDOWN_TIMEOUT = 10
|
||||||
|
|
||||||
|
|
||||||
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||||
|
"""
|
||||||
|
Implementation of the LearnerService gRPC service
|
||||||
|
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||||
|
See services.proto for the gRPC service definition.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
shutdown_event: Event, # type: ignore
|
shutdown_event: Event, # type: ignore
|
||||||
@@ -36,7 +40,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
|||||||
|
|
||||||
yield from send_bytes_in_chunks(
|
yield from send_bytes_in_chunks(
|
||||||
buffer,
|
buffer,
|
||||||
hilserl_pb2.Parameters,
|
services_pb2.Parameters,
|
||||||
log_prefix="[LEARNER] Sending parameters",
|
log_prefix="[LEARNER] Sending parameters",
|
||||||
silent=True,
|
silent=True,
|
||||||
)
|
)
|
||||||
@@ -46,7 +50,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
|||||||
self.shutdown_event.wait(self.seconds_between_pushes)
|
self.shutdown_event.wait(self.seconds_between_pushes)
|
||||||
|
|
||||||
logging.info("[LEARNER] Stream parameters finished")
|
logging.info("[LEARNER] Stream parameters finished")
|
||||||
return hilserl_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||||
# TODO: authorize the request
|
# TODO: authorize the request
|
||||||
@@ -60,7 +64,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logging.debug("[LEARNER] Finished receiving transitions")
|
logging.debug("[LEARNER] Finished receiving transitions")
|
||||||
return hilserl_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||||
# TODO: authorize the request
|
# TODO: authorize the request
|
||||||
@@ -74,7 +78,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logging.debug("[LEARNER] Finished receiving interactions")
|
logging.debug("[LEARNER] Finished receiving interactions")
|
||||||
return hilserl_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def Ready(self, request, context): # noqa: N802
|
def Ready(self, request, context): # noqa: N802
|
||||||
return hilserl_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
||||||
# NO CHECKED-IN PROTOBUF GENCODE
|
|
||||||
# source: hilserl.proto
|
|
||||||
# Protobuf Python Version: 5.29.0
|
|
||||||
"""Generated protocol buffer code."""
|
|
||||||
from google.protobuf import descriptor as _descriptor
|
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
||||||
from google.protobuf import runtime_version as _runtime_version
|
|
||||||
from google.protobuf import symbol_database as _symbol_database
|
|
||||||
from google.protobuf.internal import builder as _builder
|
|
||||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
||||||
_runtime_version.Domain.PUBLIC,
|
|
||||||
5,
|
|
||||||
29,
|
|
||||||
0,
|
|
||||||
'',
|
|
||||||
'hilserl.proto'
|
|
||||||
)
|
|
||||||
# @@protoc_insertion_point(imports)
|
|
||||||
|
|
||||||
_sym_db = _symbol_database.Default()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
|
||||||
|
|
||||||
_globals = globals()
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
|
||||||
if not _descriptor._USE_C_DESCRIPTORS:
|
|
||||||
DESCRIPTOR._loaded_options = None
|
|
||||||
_globals['_TRANSFERSTATE']._serialized_start=275
|
|
||||||
_globals['_TRANSFERSTATE']._serialized_end=371
|
|
||||||
_globals['_TRANSITION']._serialized_start=27
|
|
||||||
_globals['_TRANSITION']._serialized_end=102
|
|
||||||
_globals['_PARAMETERS']._serialized_start=104
|
|
||||||
_globals['_PARAMETERS']._serialized_end=179
|
|
||||||
_globals['_INTERACTIONMESSAGE']._serialized_start=181
|
|
||||||
_globals['_INTERACTIONMESSAGE']._serialized_end=264
|
|
||||||
_globals['_EMPTY']._serialized_start=266
|
|
||||||
_globals['_EMPTY']._serialized_end=273
|
|
||||||
_globals['_LEARNERSERVICE']._serialized_start=374
|
|
||||||
_globals['_LEARNERSERVICE']._serialized_end=696
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
|
||||||
@@ -1,546 +0,0 @@
|
|||||||
# ruff: noqa: N806, N815, N803
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from scipy.spatial.transform import Rotation
|
|
||||||
|
|
||||||
|
|
||||||
def skew_symmetric(w):
    """Return the 3x3 skew-symmetric (cross-product) matrix of a 3D vector.

    Args:
        w: Indexable holding the three components ``w[0]``, ``w[1]``, ``w[2]``.

    Returns:
        A 3x3 numpy array ``W`` such that ``W @ v`` equals ``np.cross(w, v)``.
    """
    return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])


def rodrigues_rotation(w, theta):
    """Compute a rotation matrix with Rodrigues' formula.

    Evaluates ``I + sin(theta) * W + (1 - cos(theta)) * W @ W`` where ``W`` is
    the skew-symmetric matrix of ``w``.

    Args:
        w: Rotation axis as a 3-vector. Its length is not checked here; the
            formula only yields a proper rotation for a unit-length axis.
        theta: Rotation angle in radians.

    Returns:
        A 3x3 numpy array.
    """
    w_hat = skew_symmetric(w)
    return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat


def screw_axis_to_transform(S, theta, tol=1e-6):
    """Convert a screw axis and a joint value into a 4x4 homogeneous transform.

    Args:
        S: Array-like of six values ``(w_x, w_y, w_z, v_x, v_y, v_z)``: the
            angular part followed by the linear part of the screw axis.
        theta: Joint displacement (radians for a rotation, distance for a
            pure translation).
        tol: Absolute tolerance used when checking that the relevant part of
            the screw axis has unit norm.

    Returns:
        A 4x4 numpy float array.

    Raises:
        ValueError: If neither the angular part has unit norm, nor the angular
            part is zero with a unit-norm linear part.
    """
    # Work on a float array so list/tuple input behaves like ndarray input.
    S = np.asarray(S, dtype=float)
    S_w = S[:3]
    S_v = S[3:]

    T = np.eye(4)
    # Norms are compared with a tolerance: a normalized axis is rarely *exactly* 1.0.
    if np.allclose(S_w, 0) and np.isclose(np.linalg.norm(S_v), 1.0, rtol=0.0, atol=tol):
        # Pure translation along S_v
        T[:3, 3] = S_v * theta
    elif np.isclose(np.linalg.norm(S_w), 1.0, rtol=0.0, atol=tol):
        # Rotation about S_w combined with the induced translation
        w_hat = skew_symmetric(S_w)
        T[:3, :3] = rodrigues_rotation(S_w, theta)
        T[:3, 3] = (
            np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat
        ) @ S_v
    else:
        raise ValueError("Invalid screw axis parameters")
    return T
|
|
||||||
|
|
||||||
|
|
||||||
def pose_difference_se3(pose1, pose2):
    """Compute the SE(3) difference ``pose1 - pose2`` between two 4x4 homogeneous transforms.

    Each pose has a 3x3 rotation matrix in the top-left block and a
    translation vector in the top-right column:

        [R11 R12 R13 tx]
        [R21 R22 R23 ty]
        [R31 R32 R33 tz]
        [  0   0   0  1]

    The translational part of the result is ``t1 - t2``; the rotational part
    is the rotation vector (axis-angle) of ``R1 @ R2.T``.

    Args:
        pose1: Array-like 4x4 matrix for the first pose.
        pose2: Array-like 4x4 matrix for the second pose.

    Returns:
        A single numpy array of length 6: the three translation differences
        followed by the three rotation-vector components.
    """
    pose1 = np.asarray(pose1)
    pose2 = np.asarray(pose2)

    translation_diff = pose1[:3, 3] - pose2[:3, 3]

    # Relative rotation taking pose2's orientation to pose1's, as a rotation vector
    R_diff = Rotation.from_matrix(pose1[:3, :3] @ pose2[:3, :3].T)
    rotation_diff = R_diff.as_rotvec()

    return np.concatenate([translation_diff, rotation_diff])


def se3_error(target_pose, current_pose):
    """Return the 6-element pose error ``target_pose - current_pose``.

    Same computation as ``pose_difference_se3``: position error followed by
    the rotation vector of ``R_target @ R_current.T``.

    Args:
        target_pose: Array-like 4x4 homogeneous transform of the target.
        current_pose: Array-like 4x4 homogeneous transform of the current pose.

    Returns:
        A numpy array of length 6 (position error, then rotation error).
    """
    return pose_difference_se3(target_pose, current_pose)
|
|
||||||
|
|
||||||
|
|
||||||
class RobotKinematics:
|
|
||||||
"""Robot kinematics class supporting multiple robot models."""
|
|
||||||
|
|
||||||
# Robot measurements dictionary
|
|
||||||
ROBOT_MEASUREMENTS = {
|
|
||||||
"koch": {
|
|
||||||
"gripper": [0.239, -0.001, 0.024],
|
|
||||||
"wrist": [0.209, 0, 0.024],
|
|
||||||
"forearm": [0.108, 0, 0.02],
|
|
||||||
"humerus": [0, 0, 0.036],
|
|
||||||
"shoulder": [0, 0, 0],
|
|
||||||
"base": [0, 0, 0.02],
|
|
||||||
},
|
|
||||||
"so100": {
|
|
||||||
"gripper": [0.320, 0, 0.050],
|
|
||||||
"wrist": [0.278, 0, 0.050],
|
|
||||||
"forearm": [0.143, 0, 0.044],
|
|
||||||
"humerus": [0.031, 0, 0.072],
|
|
||||||
"shoulder": [0, 0, 0],
|
|
||||||
"base": [0, 0, 0.02],
|
|
||||||
},
|
|
||||||
"moss": {
|
|
||||||
"gripper": [0.246, 0.013, 0.111],
|
|
||||||
"wrist": [0.245, 0.002, 0.064],
|
|
||||||
"forearm": [0.122, 0, 0.064],
|
|
||||||
"humerus": [0.001, 0.001, 0.063],
|
|
||||||
"shoulder": [0, 0, 0],
|
|
||||||
"base": [0, 0, 0.02],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, robot_type="so100"):
|
|
||||||
"""Initialize kinematics for the specified robot type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
robot_type: String specifying the robot model ("koch", "so100", or "moss")
|
|
||||||
"""
|
|
||||||
if robot_type not in self.ROBOT_MEASUREMENTS:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.robot_type = robot_type
|
|
||||||
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
|
|
||||||
|
|
||||||
# Initialize all transformation matrices and screw axes
|
|
||||||
self._setup_transforms()
|
|
||||||
|
|
||||||
def _create_translation_matrix(self, x=0, y=0, z=0):
|
|
||||||
"""Create a 4x4 translation matrix."""
|
|
||||||
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
|
|
||||||
|
|
||||||
def _setup_transforms(self):
|
|
||||||
"""Setup all transformation matrices and screw axes for the robot."""
|
|
||||||
# Set up rotation matrices (constant across robot types)
|
|
||||||
|
|
||||||
# Gripper orientation
|
|
||||||
self.gripper_X0 = np.array(
|
|
||||||
[
|
|
||||||
[1, 0, 0, 0],
|
|
||||||
[0, 0, 1, 0],
|
|
||||||
[0, -1, 0, 0],
|
|
||||||
[0, 0, 0, 1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wrist orientation
|
|
||||||
self.wrist_X0 = np.array(
|
|
||||||
[
|
|
||||||
[0, -1, 0, 0],
|
|
||||||
[1, 0, 0, 0],
|
|
||||||
[0, 0, 1, 0],
|
|
||||||
[0, 0, 0, 1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Base orientation
|
|
||||||
self.base_X0 = np.array(
|
|
||||||
[
|
|
||||||
[0, 0, 1, 0],
|
|
||||||
[1, 0, 0, 0],
|
|
||||||
[0, 1, 0, 0],
|
|
||||||
[0, 0, 0, 1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Gripper
|
|
||||||
# Screw axis of gripper frame wrt base frame
|
|
||||||
self.S_BG = np.array(
|
|
||||||
[
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
self.measurements["gripper"][2],
|
|
||||||
-self.measurements["gripper"][1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Gripper origin to centroid transform
|
|
||||||
self.X_GoGc = self._create_translation_matrix(x=0.07)
|
|
||||||
|
|
||||||
# Gripper origin to tip transform
|
|
||||||
self.X_GoGt = self._create_translation_matrix(x=0.12)
|
|
||||||
|
|
||||||
# 0-position gripper frame pose wrt base
|
|
||||||
self.X_BoGo = self._create_translation_matrix(
|
|
||||||
x=self.measurements["gripper"][0],
|
|
||||||
y=self.measurements["gripper"][1],
|
|
||||||
z=self.measurements["gripper"][2],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wrist
|
|
||||||
# Screw axis of wrist frame wrt base frame
|
|
||||||
self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
|
|
||||||
|
|
||||||
# 0-position origin to centroid transform
|
|
||||||
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
|
|
||||||
|
|
||||||
# 0-position wrist frame pose wrt base
|
|
||||||
self.X_BR = self._create_translation_matrix(
|
|
||||||
x=self.measurements["wrist"][0],
|
|
||||||
y=self.measurements["wrist"][1],
|
|
||||||
z=self.measurements["wrist"][2],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Forearm
|
|
||||||
# Screw axis of forearm frame wrt base frame
|
|
||||||
self.S_BF = np.array(
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
-self.measurements["forearm"][2],
|
|
||||||
0,
|
|
||||||
self.measurements["forearm"][0],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Forearm origin + centroid transform
|
|
||||||
self.X_FoFc = self._create_translation_matrix(x=0.036) # spellchecker:disable-line
|
|
||||||
|
|
||||||
# 0-position forearm frame pose wrt base
|
|
||||||
self.X_BF = self._create_translation_matrix(
|
|
||||||
x=self.measurements["forearm"][0],
|
|
||||||
y=self.measurements["forearm"][1],
|
|
||||||
z=self.measurements["forearm"][2],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Humerus
|
|
||||||
# Screw axis of humerus frame wrt base frame
|
|
||||||
self.S_BH = np.array(
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
-1,
|
|
||||||
0,
|
|
||||||
self.measurements["humerus"][2],
|
|
||||||
0,
|
|
||||||
-self.measurements["humerus"][0],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Humerus origin to centroid transform
|
|
||||||
self.X_HoHc = self._create_translation_matrix(x=0.0475)
|
|
||||||
|
|
||||||
# 0-position humerus frame pose wrt base
|
|
||||||
self.X_BH = self._create_translation_matrix(
|
|
||||||
x=self.measurements["humerus"][0],
|
|
||||||
y=self.measurements["humerus"][1],
|
|
||||||
z=self.measurements["humerus"][2],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Shoulder
|
|
||||||
# Screw axis of shoulder frame wrt Base frame
|
|
||||||
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
|
|
||||||
|
|
||||||
# Shoulder origin to centroid transform
|
|
||||||
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
|
|
||||||
|
|
||||||
# 0-position shoulder frame pose wrt base
|
|
||||||
self.X_BS = self._create_translation_matrix(
|
|
||||||
x=self.measurements["shoulder"][0],
|
|
||||||
y=self.measurements["shoulder"][1],
|
|
||||||
z=self.measurements["shoulder"][2],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Base
|
|
||||||
# Base origin to centroid transform
|
|
||||||
self.X_BoBc = self._create_translation_matrix(y=0.015)
|
|
||||||
|
|
||||||
# World to base transform
|
|
||||||
self.X_WoBo = self._create_translation_matrix(
|
|
||||||
x=self.measurements["base"][0],
|
|
||||||
y=self.measurements["base"][1],
|
|
||||||
z=self.measurements["base"][2],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pre-compute gripper post-multiplication matrix
|
|
||||||
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
|
|
||||||
|
|
||||||
def fk_base(self):
    """Return the pose of the base frame.

    The pose is the product of the world-to-base transform, the base
    origin-to-centroid transform and the stored ``base_X0`` matrix, in
    that order. It does not depend on any joint angle.
    """
    world_to_base_centroid = self.X_WoBo @ self.X_BoBc
    return world_to_base_centroid @ self.base_X0
|
|
||||||
|
|
||||||
def fk_shoulder(self, robot_pos_deg):
    """Return the pose of the shoulder frame.

    Args:
        robot_pos_deg: Joint positions in degrees; only index 0 (the
            shoulder joint) is used here.
    """
    joint_angles_rad = robot_pos_deg / 180 * np.pi
    # World -> base, rotated about the shoulder screw axis.
    pose = self.X_WoBo @ screw_axis_to_transform(self.S_BS, joint_angles_rad[0])
    # Shoulder origin-to-centroid offset, then the 0-position shoulder pose.
    pose = pose @ self.X_SoSc
    return pose @ self.X_BS
|
|
||||||
|
|
||||||
def fk_humerus(self, robot_pos_deg):
    """Return the pose of the humerus frame.

    Args:
        robot_pos_deg: Joint positions in degrees; indices 0 and 1 are used.
    """
    joint_angles_rad = robot_pos_deg / 180 * np.pi

    # Chain the joint transforms left to right, starting from world -> base.
    pose = self.X_WoBo
    for joint_ix, screw_axis in enumerate((self.S_BS, self.S_BH)):
        pose = pose @ screw_axis_to_transform(screw_axis, joint_angles_rad[joint_ix])

    # Humerus origin-to-centroid offset, then the 0-position humerus pose.
    pose = pose @ self.X_HoHc
    return pose @ self.X_BH
|
|
||||||
|
|
||||||
def fk_forearm(self, robot_pos_deg):
    """Return the pose of the forearm frame.

    Args:
        robot_pos_deg: Joint positions in degrees; indices 0 to 2 are used.
    """
    joint_angles_rad = robot_pos_deg / 180 * np.pi

    # Chain the joint transforms left to right, starting from world -> base.
    pose = self.X_WoBo
    for joint_ix, screw_axis in enumerate((self.S_BS, self.S_BH, self.S_BF)):
        pose = pose @ screw_axis_to_transform(screw_axis, joint_angles_rad[joint_ix])

    # Forearm origin-to-centroid offset, then the 0-position forearm pose.
    pose = pose @ self.X_FoFc  # spellchecker:disable-line
    return pose @ self.X_BF
|
|
||||||
|
|
||||||
def fk_wrist(self, robot_pos_deg):
    """Return the pose of the wrist frame.

    Args:
        robot_pos_deg: Joint positions in degrees; indices 0 to 3 are used.
    """
    joint_angles_rad = robot_pos_deg / 180 * np.pi

    # Chain the joint transforms left to right, starting from world -> base.
    pose = self.X_WoBo
    for joint_ix, screw_axis in enumerate((self.S_BS, self.S_BH, self.S_BF, self.S_BR)):
        pose = pose @ screw_axis_to_transform(screw_axis, joint_angles_rad[joint_ix])

    # Fixed transforms applied after the joints, in the stored order.
    for fixed_transform in (self.X_RoRc, self.X_BR, self.wrist_X0):
        pose = pose @ fixed_transform
    return pose
|
|
||||||
|
|
||||||
def fk_gripper(self, robot_pos_deg):
    """Return the pose of the gripper frame.

    Args:
        robot_pos_deg: Joint positions in degrees; indices 0 to 4 are used.
    """
    joint_angles_rad = robot_pos_deg / 180 * np.pi
    joint_axes = (self.S_BS, self.S_BH, self.S_BF, self.S_BR, self.S_BG)

    # Chain the joint transforms left to right, starting from world -> base.
    pose = self.X_WoBo
    for joint_ix, screw_axis in enumerate(joint_axes):
        pose = pose @ screw_axis_to_transform(screw_axis, joint_angles_rad[joint_ix])

    # Constant tail of the chain, pre-multiplied once by the initializer.
    return pose @ self._fk_gripper_post
|
|
||||||
|
|
||||||
def fk_gripper_tip(self, robot_pos_deg):
    """Return the pose of the gripper tip frame.

    Args:
        robot_pos_deg: Joint positions in degrees; indices 0 to 4 are used.
    """
    joint_angles_rad = robot_pos_deg / 180 * np.pi
    joint_axes = (self.S_BS, self.S_BH, self.S_BF, self.S_BR, self.S_BG)

    # Chain the joint transforms left to right, starting from world -> base.
    pose = self.X_WoBo
    for joint_ix, screw_axis in enumerate(joint_axes):
        pose = pose @ screw_axis_to_transform(screw_axis, joint_angles_rad[joint_ix])

    # Fixed transforms applied after the joints, in the stored order.
    for fixed_transform in (self.X_GoGt, self.X_BoGo, self.gripper_X0):
        pose = pose @ fixed_transform
    return pose
|
|
||||||
|
|
||||||
def compute_jacobian(self, robot_pos_deg, fk_func=None):
    """Compute the pose Jacobian by central finite differences.

    Column j holds the pose difference (as returned by
    ``pose_difference_se3``) between the forward kinematics evaluated at
    joint j perturbed by +eps/2 and -eps/2 degrees, divided by eps. The
    Jacobian is therefore expressed per degree of joint position.

    The last entry of ``robot_pos_deg`` is excluded: it is neither perturbed
    nor passed to ``fk_func``.

    Args:
        robot_pos_deg: Current joint positions in degrees.
        fk_func: Forward kinematics function to use (defaults to fk_gripper).

    Returns:
        Array of shape (6, N) where N = len(robot_pos_deg) - 1. For the
        usual 6-entry joint vector this is (6, 5).
    """
    if fk_func is None:
        fk_func = self.fk_gripper

    eps = 1e-8
    # Slice once instead of on every loop iteration.
    joint_pos_deg = robot_pos_deg[:-1]
    num_joints = len(joint_pos_deg)

    # Size the Jacobian from the input rather than hard-coding 5 columns, so
    # a custom fk_func with a different joint count does not index past the end.
    jac = np.zeros(shape=(6, num_joints))
    delta = np.zeros(num_joints, dtype=np.float64)
    for el_ix in range(num_joints):
        # Perturb only joint el_ix on this pass.
        delta[:] = 0
        delta[el_ix] = eps / 2
        # NOTE(review): assumes pose_difference_se3 returns a 6-vector — confirm.
        jac[:, el_ix] = (
            pose_difference_se3(
                fk_func(joint_pos_deg + delta),
                fk_func(joint_pos_deg - delta),
            )
            / eps
        )
    return jac
|
|
||||||
|
|
||||||
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
    """Compute the positional Jacobian by central finite differences.

    J(i, j) is the change of the ith component of the end-effector position
    (the translation column ``[:3, 3]`` of the pose returned by ``fk_func``)
    per degree of change in the jth joint position.

    The last entry of ``robot_pos_deg`` is excluded: it is neither perturbed
    nor passed to ``fk_func``.

    Args:
        robot_pos_deg: Current joint positions in degrees.
        fk_func: Forward kinematics function to use (defaults to fk_gripper).

    Returns:
        Array of shape (3, N) where N = len(robot_pos_deg) - 1. For the
        usual 6-entry joint vector this is (3, 5).
    """
    if fk_func is None:
        fk_func = self.fk_gripper

    eps = 1e-8
    # Slice once instead of on every loop iteration.
    joint_pos_deg = robot_pos_deg[:-1]
    num_joints = len(joint_pos_deg)

    # Size the Jacobian from the input rather than hard-coding 5 columns, so
    # a custom fk_func with a different joint count does not index past the end.
    jac = np.zeros(shape=(3, num_joints))
    delta = np.zeros(num_joints, dtype=np.float64)
    for el_ix in range(num_joints):
        # Perturb only joint el_ix on this pass.
        delta[:] = 0
        delta[el_ix] = eps / 2
        position_plus = fk_func(joint_pos_deg + delta)[:3, 3]
        position_minus = fk_func(joint_pos_deg - delta)[:3, 3]
        jac[:, el_ix] = (position_plus - position_minus) / eps
    return jac
|
|
||||||
|
|
||||||
def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
    """Inverse kinematics by iterated pseudo-inverse Jacobian updates.

    Each iteration evaluates the forward kinematics, computes the error to
    the target and updates every joint except the last one with
    ``pinv(J) @ error``. The last entry of the joint vector is never changed.

    The caller's array is not modified: the iteration runs on a copy.
    Non-floating input (e.g. an integer array or a list of ints) is promoted
    to float64 so the in-place update cannot fail on dtype casting; floating
    input keeps its dtype.

    Args:
        current_joint_state: Initial joint positions in degrees.
        desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix.
        position_only: If True, only match end-effector position, not orientation.
        fk_func: Forward kinematics function to use (defaults to fk_gripper).

    Returns:
        A new array of joint positions in degrees after at most 5 update
        steps. The loop stops early once the error measured before the
        latest update is below 5e-3.
    """
    if fk_func is None:
        fk_func = self.fk_gripper

    # Copy so the caller's state is left untouched, and make sure the dtype
    # can hold the fractional updates added below.
    joint_state = np.array(current_joint_state)
    if not np.issubdtype(joint_state.dtype, np.floating):
        joint_state = joint_state.astype(np.float64)

    max_iterations = 5
    learning_rate = 1
    for _ in range(max_iterations):
        current_ee_pose = fk_func(joint_state)
        if position_only:
            # Error between the translation parts of the two poses.
            error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
            jac = self.compute_positional_jacobian(joint_state, fk_func)
        else:
            error = se3_error(desired_ee_pose, current_ee_pose)
            jac = self.compute_jacobian(joint_state, fk_func)

        delta_angles = np.linalg.pinv(jac) @ error
        joint_state[:-1] += learning_rate * delta_angles

        # The threshold is tested on the error computed before this update,
        # so one more step is always applied once the target is reached.
        if np.linalg.norm(error) < 5e-3:
            break
    return joint_state
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Self-test: prints PASSED/FAILED lines per robot type when the module is
    # run as a script.
    import time

    def run_test(robot_type):
        """Run the test suite for one robot type.

        Returns True only if the kinematic-chain check, both Jacobian shape
        checks and the position-only IK accuracy check all pass.
        """
        print(f"\n--- Testing {robot_type.upper()} Robot ---")

        # Initialize kinematics for this robot
        robot = RobotKinematics(robot_type)

        # Test 1: Forward kinematics consistency
        print("Test 1: Forward kinematics consistency")
        test_angles = np.array([30, 45, -30, 20, 10, 0])  # Example joint angles in degrees

        # Calculate FK for each frame along the chain, base side first
        chain_poses = [
            robot.fk_shoulder(test_angles),
            robot.fk_humerus(test_angles),
            robot.fk_forearm(test_angles),
            robot.fk_wrist(test_angles),
            robot.fk_gripper(test_angles),
            robot.fk_gripper_tip(test_angles),
        ]

        # Check that poses form a consistent kinematic chain (positions should be
        # progressively further from origin). This is a heuristic for the chosen
        # test angles, not a general property of the arm.
        distances = [np.linalg.norm(pose[:3, 3]) for pose in chain_poses]
        is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
        print(f"  Pose distances from origin: {[round(d, 3) for d in distances]}")
        print(f"  Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")

        # Test 2: Jacobian computation
        print("Test 2: Jacobian computation")
        jacobian = robot.compute_jacobian(test_angles)
        positional_jacobian = robot.compute_positional_jacobian(test_angles)

        # Check shapes
        jacobian_shape_ok = jacobian.shape == (6, 5)
        pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)

        print(f"  Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
        print(f"  Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")

        # Test 3: Inverse kinematics
        print("Test 3: Inverse kinematics (position only)")

        # Generate target pose from known joint angles
        original_angles = np.array([10, 20, 30, -10, 5, 0])
        target_pose = robot.fk_gripper(original_angles)

        # Start IK from a different position
        initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

        # Measure IK performance. perf_counter is monotonic, so the interval
        # cannot be distorted by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        computed_angles = robot.ik(initial_guess.copy(), target_pose)
        ik_time = time.perf_counter() - start_time

        # Compute resulting pose from IK solution
        result_pose = robot.fk_gripper(computed_angles)

        # Calculate position error
        pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
        passed = pos_error < 0.01  # Accept errors less than 1cm

        print(f"  IK computation time: {ik_time:.4f} seconds")
        print(f"  Position error: {pos_error:.4f}")
        print(f"  IK position accuracy: {'PASSED' if passed else 'FAILED'}")

        return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed

    # Run tests for all robot types
    results = {}
    for robot_type in ["koch", "so100", "moss"]:
        results[robot_type] = run_test(robot_type)

    # Print overall summary
    print("\n=== Test Summary ===")
    all_passed = all(results.values())
    for robot_type, passed in results.items():
        print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
    print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")
|
|
||||||
@@ -5,7 +5,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.scripts.server.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user