[PORT HIL-SERL] Refactor folders structure | Rebased version (#1178)

This commit is contained in:
Eugene Mironov
2025-06-02 14:46:56 +07:00
committed by AdilZouitine
parent 8feda920da
commit 6eeab64f8a
19 changed files with 270 additions and 776 deletions

View File

@@ -606,7 +606,7 @@ class SACObservationEncoder(nn.Module):
Usage patterns:
- Called in select_action() with normalize=True
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
- Called internally by forward() with normalize=False
Args:

View File

@@ -13,9 +13,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// To generate the classes for the transport part (services_pb2.py and services_pb2_grpc.py), use the following command:
//
// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. lerobot/common/transport/services.proto
//
// The command should be launched from the root of the project.
syntax = "proto3";
package hil_serl;
package transport;
// LearnerService: the Actor calls this to push transitions.
// The Learner implements this service.

View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: lerobot/common/transport/services.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'lerobot/common/transport/services.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'lerobot/common/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xcc\x02\n\x0eLearnerService\x12I\n\x16SendInteractionMessage\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.common.transport.services_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=305
_globals['_TRANSFERSTATE']._serialized_end=401
_globals['_TRANSITION']._serialized_start=54
_globals['_TRANSITION']._serialized_end=130
_globals['_PARAMETERS']._serialized_start=132
_globals['_PARAMETERS']._serialized_end=208
_globals['_INTERACTIONMESSAGE']._serialized_start=210
_globals['_INTERACTIONMESSAGE']._serialized_end=294
_globals['_EMPTY']._serialized_start=296
_globals['_EMPTY']._serialized_end=303
_globals['_LEARNERSERVICE']._serialized_start=404
_globals['_LEARNERSERVICE']._serialized_end=736
# @@protoc_insertion_point(module_scope)

View File

@@ -3,9 +3,9 @@
import grpc
import warnings
import hilserl_pb2 as hilserl__pb2
from lerobot.common.transport import services_pb2 as lerobot_dot_common_dot_transport_dot_services__pb2
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
@@ -18,7 +18,7 @@ except ImportError:
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
+ f' but the generated code in lerobot/common/transport/services_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -37,29 +37,29 @@ class LearnerServiceStub(object):
channel: A grpc.Channel.
"""
self.SendInteractionMessage = channel.unary_unary(
'/hil_serl.LearnerService/SendInteractionMessage',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
'/transport.LearnerService/SendInteractionMessage',
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.StreamParameters = channel.unary_stream(
'/hil_serl.LearnerService/StreamParameters',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Parameters.FromString,
'/transport.LearnerService/StreamParameters',
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString,
_registered_method=True)
self.SendTransitions = channel.stream_unary(
'/hil_serl.LearnerService/SendTransitions',
request_serializer=hilserl__pb2.Transition.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
'/transport.LearnerService/SendTransitions',
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString,
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractions = channel.stream_unary(
'/hil_serl.LearnerService/SendInteractions',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
'/transport.LearnerService/SendInteractions',
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/hil_serl.LearnerService/Ready',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
'/transport.LearnerService/Ready',
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
@@ -104,34 +104,34 @@ def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
servicer.SendInteractionMessage,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString,
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'StreamParameters': grpc.unary_stream_rpc_method_handler(
servicer.StreamParameters,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Parameters.SerializeToString,
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.SerializeToString,
),
'SendTransitions': grpc.stream_unary_rpc_method_handler(
servicer.SendTransitions,
request_deserializer=hilserl__pb2.Transition.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.FromString,
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'SendInteractions': grpc.stream_unary_rpc_method_handler(
servicer.SendInteractions,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString,
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.LearnerService', rpc_method_handlers)
'transport.LearnerService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
server.add_registered_method_handlers('transport.LearnerService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
@@ -154,9 +154,9 @@ class LearnerService(object):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/SendInteractionMessage',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
'/transport.LearnerService/SendInteractionMessage',
lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -181,9 +181,9 @@ class LearnerService(object):
return grpc.experimental.unary_stream(
request,
target,
'/hil_serl.LearnerService/StreamParameters',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Parameters.FromString,
'/transport.LearnerService/StreamParameters',
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString,
options,
channel_credentials,
insecure,
@@ -208,9 +208,9 @@ class LearnerService(object):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendTransitions',
hilserl__pb2.Transition.SerializeToString,
hilserl__pb2.Empty.FromString,
'/transport.LearnerService/SendTransitions',
lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString,
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -235,9 +235,9 @@ class LearnerService(object):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendInteractions',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
'/transport.LearnerService/SendInteractions',
lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -262,9 +262,9 @@ class LearnerService(object):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/Ready',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Empty.FromString,
'/transport.LearnerService/Ready',
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,

View File

@@ -23,8 +23,8 @@ from typing import Any
import torch
from lerobot.scripts.server import hilserl_pb2
from lerobot.scripts.server.utils import Transition
from lerobot.common.transport import services_pb2
from lerobot.common.utils.transition import Transition
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
@@ -47,12 +47,12 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
transfer_state = services_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
transfer_state = services_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
@@ -75,18 +75,18 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
logging.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step 0")
step = 0
continue
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
elif item.transfer_state == services_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")

View File

@@ -23,7 +23,7 @@ import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.utils import Transition
from lerobot.common.utils.transition import Transition
class BatchTransition(TypedDict):

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
import sys
# Number of shutdown signals received so far by this process. It is module-level
# so that it is shared by every handler installed through setup_process_handlers.
shutdown_event_counter = 0


def setup_process_handlers(use_threads: bool) -> any:
    """Install shutdown signal handlers and return the event they set.

    The first shutdown signal sets the returned event so that workers polling
    it can stop cooperatively. Any further signal exits the process
    immediately with status 1.

    Args:
        use_threads: If True the returned event is a ``threading.Event``,
            otherwise a ``multiprocessing.Event``.

    Returns:
        The shutdown event that is set when a shutdown signal is received.
    """
    if use_threads:
        from threading import Event
    else:
        from multiprocessing import Event

    shutdown_event = Event()

    def signal_handler(signum, frame):
        global shutdown_event_counter

        logging.info("Shutdown signal received. Cleaning up...")
        shutdown_event.set()
        shutdown_event_counter += 1

        # A repeated signal means the graceful path did not finish: force the exit.
        if shutdown_event_counter > 1:
            logging.info("Force shutdown")
            sys.exit(1)

    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Termination request (kill)

    # SIGHUP (terminal closed/hangup) and SIGQUIT (Ctrl+\) are not defined on
    # every platform, so they are registered only where the signal module has them.
    for optional_signal_name in ("SIGHUP", "SIGQUIT"):
        optional_signal = getattr(signal, optional_signal_name, None)
        if optional_signal is not None:
            signal.signal(optional_signal, signal_handler)

    return shutdown_event

View File

@@ -0,0 +1,35 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from queue import Empty, Queue
def get_last_item_from_queue(queue: Queue):
    """Return the most recently queued item, discarding everything older.

    Blocks until at least one item is available, then empties the queue
    without blocking and keeps only the last item retrieved.

    Args:
        queue: The queue to drain.

    Returns:
        The last item taken from the queue.
    """
    latest = queue.get()
    drained = 1

    # Keep pulling without blocking until the queue reports it is empty.
    while True:
        try:
            latest = queue.get_nowait()
        except Empty:
            break
        drained += 1

    logging.debug(f"Drained {drained} items from queue")
    return latest

View File

@@ -1,7 +1,6 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,64 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
import sys
from queue import Empty
from typing import TypedDict
import torch
from torch.multiprocessing import Queue
shutdown_event_counter = 0
def setup_process_handlers(use_threads: bool) -> any:
if use_threads:
from threading import Event
else:
from multiprocessing import Event
shutdown_event = Event()
# Define signal handler
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
global shutdown_event_counter
shutdown_event_counter += 1
if shutdown_event_counter > 1:
logging.info("Force shutdown")
sys.exit(1)
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
return shutdown_event
def get_last_item_from_queue(queue: Queue):
item = queue.get()
counter = 1
# Drain queue and keep only the most recent parameters
try:
while True:
item = queue.get_nowait()
counter += 1
except Empty:
pass
logging.debug(f"Drained {counter} items from queue")
return item
class Transition(TypedDict):

View File

@@ -24,12 +24,12 @@ Examples of usage:
- Start an actor server for real robot training with human-in-the-loop intervention:
```bash
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with a specific robot type for a pick and place task:
```bash
python lerobot/scripts/server/actor_server.py \
python lerobot/scripts/rl/actor.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--robot.type=so100 \
--task=pick_and_place
@@ -37,7 +37,7 @@ python lerobot/scripts/server/actor_server.py \
- Set a custom workspace bound for the robot's end-effector:
```bash
python lerobot/scripts/server/actor_server.py \
python lerobot/scripts/rl/actor.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
@@ -45,7 +45,7 @@ python lerobot/scripts/server/actor_server.py \
- Run with specific camera crop parameters:
```bash
python lerobot/scripts/server/actor_server.py \
python lerobot/scripts/rl/actor.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
```
@@ -85,8 +85,23 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
from lerobot.common.transport import services_pb2, services_pb2_grpc
from lerobot.common.transport.utils import (
bytes_to_state_dict,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.common.utils.process import setup_process_handlers
from lerobot.common.utils.queue import get_last_item_from_queue
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.transition import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
@@ -94,22 +109,8 @@ from lerobot.common.utils.utils import (
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import Transition
from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.scripts.server.network_utils import (
bytes_to_state_dict,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.scripts.server.utils import (
get_last_item_from_queue,
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
from lerobot.scripts.rl import learner_service
from lerobot.scripts.rl.gym_manipulator import make_robot_env
ACTOR_SHUTDOWN_TIMEOUT = 30
@@ -386,14 +387,14 @@ def act_with_policy(
def establish_learner_connection(
stub: hilserl_pb2_grpc.LearnerServiceStub,
stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
attempts: int = 30,
):
"""Establish a connection with the learner.
Args:
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
shutdown_event (Event): The event to check if the connection should be established.
attempts (int): The number of attempts to establish the connection.
Returns:
@@ -407,7 +408,7 @@ def establish_learner_connection(
# Force a connection attempt and check state
try:
logging.info("[ACTOR] Send ready message to Learner")
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
if stub.Ready(services_pb2.Empty()) == services_pb2.Empty():
return True
except grpc.RpcError as e:
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
@@ -419,7 +420,7 @@ def establish_learner_connection(
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
"""
@@ -458,7 +459,7 @@ def learner_service_client(
("grpc.service_config", service_config_json),
],
)
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
stub = services_pb2_grpc.LearnerServiceStub(channel)
logging.info("[ACTOR] Learner service client created")
return stub, channel
@@ -467,7 +468,7 @@ def receive_policy(
cfg: TrainPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
"""Receive parameters from the learner.
@@ -499,7 +500,7 @@ def receive_policy(
)
try:
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
iterator = learner_client.StreamParameters(services_pb2.Empty())
receive_bytes_in_chunks(
iterator,
parameters_queue,
@@ -519,9 +520,9 @@ def send_transitions(
cfg: TrainPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
) -> services_pb2.Empty:
"""
Sends transitions to the learner.
@@ -530,7 +531,7 @@ def send_transitions(
- Transition Data:
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
"""
if not use_threads(cfg):
@@ -569,9 +570,9 @@ def send_interactions(
cfg: TrainPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
) -> services_pb2.Empty:
"""
Sends interactions to the learner.
@@ -614,7 +615,7 @@ def send_interactions(
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> services_pb2.Empty: # type: ignore
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
@@ -623,16 +624,16 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilse
continue
yield from send_bytes_in_chunks(
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions"
)
return hilserl_pb2.Empty()
return services_pb2.Empty()
def interactions_stream(
shutdown_event: Event, # type: ignore
interactions_queue: Queue,
) -> hilserl_pb2.Empty:
) -> services_pb2.Empty:
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=5)
@@ -642,11 +643,11 @@ def interactions_stream(
yield from send_bytes_in_chunks(
message,
hilserl_pb2.InteractionMessage,
services_pb2.InteractionMessage,
log_prefix="[ACTOR] Send interactions",
)
return hilserl_pb2.Empty()
return services_pb2.Empty()
#################################################

View File

@@ -25,12 +25,12 @@ Examples of usage:
- Start a learner server for training:
```bash
python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with specific SAC hyperparameters:
```bash
python lerobot/scripts/server/learner_server.py \
python lerobot/scripts/rl/learner.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--learner.sac.alpha=0.1 \
--learner.sac.gamma=0.99
@@ -38,7 +38,7 @@ python lerobot/scripts/server/learner_server.py \
- Run with a specific dataset and wandb logging:
```bash
python lerobot/scripts/server/learner_server.py \
python lerobot/scripts/rl/learner.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--dataset.repo_id=username/pick_lift_cube \
--wandb.enable=true \
@@ -47,14 +47,14 @@ python lerobot/scripts/server/learner_server.py \
- Run with a pretrained policy for fine-tuning:
```bash
python lerobot/scripts/server/learner_server.py \
python lerobot/scripts/rl/learner.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model
```
- Run with a reward classifier model:
```bash
python lerobot/scripts/server/learner_server.py \
python lerobot/scripts/rl/learner.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--reward_classifier_pretrained_path=outputs/reward_model/best_model
```
@@ -84,7 +84,6 @@ from pathlib import Path
from pprint import pformat
import grpc
import hilserl_pb2_grpc # type: ignore
import torch
from termcolor import colored
from torch import nn
@@ -103,6 +102,14 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
from lerobot.common.transport import services_pb2_grpc
from lerobot.common.transport.utils import (
bytes_to_python_object,
bytes_to_transitions,
state_to_bytes,
)
from lerobot.common.utils.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.common.utils.process import setup_process_handlers
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
@@ -112,6 +119,7 @@ from lerobot.common.utils.train_utils import (
from lerobot.common.utils.train_utils import (
load_training_state as utils_load_training_state,
)
from lerobot.common.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
@@ -120,18 +128,7 @@ from lerobot.common.utils.utils import (
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.scripts.server.network_utils import (
bytes_to_python_object,
bytes_to_transitions,
state_to_bytes,
)
from lerobot.scripts.server.utils import (
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
from lerobot.scripts.rl import learner_service
LOG_PREFIX = "[LEARNER]"
@@ -244,7 +241,7 @@ def start_learner_threads(
concurrency_entity = Process
communication_process = concurrency_entity(
target=start_learner_server,
target=start_learner,
args=(
parameters_queue,
transition_queue,
@@ -643,7 +640,7 @@ def add_actor_information_and_train(
)
def start_learner_server(
def start_learner(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
@@ -666,7 +663,7 @@ def start_learner_server(
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
log_file = os.path.join(log_dir, f"learner_process_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
@@ -693,7 +690,7 @@ def start_learner_server(
],
)
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
services_pb2_grpc.add_LearnerServiceServicer_to_server(
service,
server,
)

View File

@@ -1,17 +1,21 @@
import logging
from multiprocessing import Event, Queue
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks
from lerobot.common.transport import services_pb2, services_pb2_grpc
from lerobot.common.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
"""
Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
See lerobot/common/transport/services.proto for the gRPC service definition.
"""
def __init__(
self,
shutdown_event: Event, # type: ignore
@@ -36,7 +40,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
yield from send_bytes_in_chunks(
buffer,
hilserl_pb2.Parameters,
services_pb2.Parameters,
log_prefix="[LEARNER] Sending parameters",
silent=True,
)
@@ -46,7 +50,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
self.shutdown_event.wait(self.seconds_between_pushes)
logging.info("[LEARNER] Stream parameters finished")
return hilserl_pb2.Empty()
return services_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
@@ -60,7 +64,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
)
logging.debug("[LEARNER] Finished receiving transitions")
return hilserl_pb2.Empty()
return services_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
@@ -74,7 +78,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
)
logging.debug("[LEARNER] Finished receiving interactions")
return hilserl_pb2.Empty()
return services_pb2.Empty()
def Ready(self, request, context): # noqa: N802
return hilserl_pb2.Empty()
return services_pb2.Empty()

View File

@@ -1,46 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: hilserl.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'hilserl.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=275
_globals['_TRANSFERSTATE']._serialized_end=371
_globals['_TRANSITION']._serialized_start=27
_globals['_TRANSITION']._serialized_end=102
_globals['_PARAMETERS']._serialized_start=104
_globals['_PARAMETERS']._serialized_end=179
_globals['_INTERACTIONMESSAGE']._serialized_start=181
_globals['_INTERACTIONMESSAGE']._serialized_end=264
_globals['_EMPTY']._serialized_start=266
_globals['_EMPTY']._serialized_end=273
_globals['_LEARNERSERVICE']._serialized_start=374
_globals['_LEARNERSERVICE']._serialized_end=696
# @@protoc_insertion_point(module_scope)

View File

@@ -1,546 +0,0 @@
# ruff: noqa: N806, N815, N803
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy.spatial.transform import Rotation
def skew_symmetric(w):
    """Return the 3x3 skew-symmetric (cross-product) matrix of a 3D vector.

    For ``w = (w0, w1, w2)`` the returned matrix ``W`` satisfies
    ``W @ v == np.cross(w, v)``. Only the first three components of ``w``
    are read.
    """
    w0, w1, w2 = w[0], w[1], w[2]
    rows = [
        [0, -w2, w1],
        [w2, 0, -w0],
        [-w1, w0, 0],
    ]
    return np.array(rows)
def rodrigues_rotation(w, theta):
    """Build a 3x3 rotation matrix from an axis and an angle (Rodrigues' formula).

    Evaluates ``I + sin(theta) * K + (1 - cos(theta)) * K @ K`` where ``K`` is
    the skew-symmetric matrix of ``w``. The axis is used as given; it is not
    normalised here.
    """
    axis_cross = skew_symmetric(w)
    sin_term = np.sin(theta) * axis_cross
    cos_term = (1 - np.cos(theta)) * axis_cross @ axis_cross
    return np.eye(3) + sin_term + cos_term
def screw_axis_to_transform(S, theta):
    """Convert a screw axis and a displacement into a 4x4 homogeneous transform.

    Args:
        S: Length-6 screw axis laid out as ``[w0, w1, w2, v0, v1, v2]``
            (angular part first, linear part second).
        theta: Displacement along the screw axis: the rotation angle when the
            angular part is a unit vector, or the travelled distance for a
            pure translation.

    Returns:
        A 4x4 numpy array holding the rotation in ``[:3, :3]`` and the
        translation in ``[:3, 3]``.

    Raises:
        ValueError: If the axis is neither a pure translation (zero angular
            part with unit linear part) nor a rotation with a unit angular
            part.
    """
    S = np.asarray(S, dtype=np.float64)
    S_w = S[:3]
    S_v = S[3:]

    T = np.eye(4)
    # Norms are compared with a tolerance: an axis normalised in floating point
    # is rarely *exactly* of length 1, and an exact `==` would reject it.
    if np.allclose(S_w, 0) and np.isclose(np.linalg.norm(S_v), 1.0):  # Pure translation
        T[:3, 3] = S_v * theta
    elif np.isclose(np.linalg.norm(S_w), 1.0):  # Rotation and translation
        w_hat = skew_symmetric(S_w)
        R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
        t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
        T[:3, :3] = R
        T[:3, 3] = t
    else:
        raise ValueError("Invalid screw axis parameters")
    return T
def pose_difference_se3(pose1, pose2):
    """Compute the SE(3) difference ``pose1 - pose2`` between two poses.

    Both poses are 4x4 homogeneous transformation matrices: a 3x3 rotation
    matrix in the top-left corner and a translation vector in the last column:

        [R11 R12 R13 tx]
        [R21 R22 R23 ty]
        [R31 R32 R33 tz]
        [  0   0   0  1]

    Args:
        pose1: A 4x4 numpy array representing the first pose.
        pose2: A 4x4 numpy array representing the second pose.

    Returns:
        A single length-6 numpy array. Elements 0-2 are the translational
        difference ``t1 - t2``; elements 3-5 are the rotation ``R1 @ R2.T``
        expressed as a rotation vector (axis-angle).
    """
    # Translation part: plain subtraction of the last columns.
    position_delta = pose1[:3, 3] - pose2[:3, 3]

    # Rotation part: relative rotation converted to a rotation vector by scipy.
    relative_rotation = pose1[:3, :3] @ pose2[:3, :3].T
    rotvec_delta = Rotation.from_matrix(relative_rotation).as_rotvec()

    return np.concatenate((position_delta, rotvec_delta))
def se3_error(target_pose, current_pose):
    """Return the 6D error between a target pose and the current pose.

    Both arguments are 4x4 homogeneous transforms. The result is a length-6
    numpy array: the position error ``target - current`` followed by the
    rotation ``R_target @ R_current.T`` expressed as a rotation vector.
    """
    rotation_gap = Rotation.from_matrix(target_pose[:3, :3] @ current_pose[:3, :3].T)
    translation_gap = target_pose[:3, 3] - current_pose[:3, 3]
    return np.concatenate([translation_gap, rotation_gap.as_rotvec()])
class RobotKinematics:
    """Forward and inverse kinematics for the supported robot arm models.

    Joint positions are passed in degrees. The forward-kinematics methods read
    up to the first five entries of the joint vector; the Jacobian and IK
    methods treat the last entry of the joint vector as not being part of the
    kinematic chain and leave it out of the differentiation/update.
    """

    # Per-robot link measurements: robot type -> link name -> [x, y, z].
    # NOTE(review): values look like metres in the base frame at the zero pose - confirm.
    ROBOT_MEASUREMENTS = {
        "koch": {
            "gripper": [0.239, -0.001, 0.024],
            "wrist": [0.209, 0, 0.024],
            "forearm": [0.108, 0, 0.02],
            "humerus": [0, 0, 0.036],
            "shoulder": [0, 0, 0],
            "base": [0, 0, 0.02],
        },
        "so100": {
            "gripper": [0.320, 0, 0.050],
            "wrist": [0.278, 0, 0.050],
            "forearm": [0.143, 0, 0.044],
            "humerus": [0.031, 0, 0.072],
            "shoulder": [0, 0, 0],
            "base": [0, 0, 0.02],
        },
        "moss": {
            "gripper": [0.246, 0.013, 0.111],
            "wrist": [0.245, 0.002, 0.064],
            "forearm": [0.122, 0, 0.064],
            "humerus": [0.001, 0.001, 0.063],
            "shoulder": [0, 0, 0],
            "base": [0, 0, 0.02],
        },
    }

    def __init__(self, robot_type="so100"):
        """Initialize kinematics for the specified robot type.

        Args:
            robot_type: String specifying the robot model ("koch", "so100", or "moss")

        Raises:
            ValueError: If ``robot_type`` is not a key of ``ROBOT_MEASUREMENTS``.
        """
        if robot_type not in self.ROBOT_MEASUREMENTS:
            raise ValueError(
                f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
            )

        self.robot_type = robot_type
        self.measurements = self.ROBOT_MEASUREMENTS[robot_type]

        # Initialize all transformation matrices and screw axes
        self._setup_transforms()

    def _create_translation_matrix(self, x=0, y=0, z=0):
        """Create a 4x4 homogeneous matrix that translates by ``(x, y, z)``."""
        return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])

    def _setup_transforms(self):
        """Set up all constant transformation matrices and screw axes for the robot."""

        def link_offset(link):
            # Translation to the measured zero-position of ``link``.
            x, y, z = self.measurements[link]
            return self._create_translation_matrix(x=x, y=y, z=z)

        # Rotation-only frames (constant across robot types)

        # Gripper orientation
        self.gripper_X0 = np.array(
            [
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, -1, 0, 0],
                [0, 0, 0, 1],
            ]
        )

        # Wrist orientation
        self.wrist_X0 = np.array(
            [
                [0, -1, 0, 0],
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ]
        )

        # Base orientation
        self.base_X0 = np.array(
            [
                [0, 0, 1, 0],
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1],
            ]
        )

        # Gripper
        # Screw axis of gripper frame wrt base frame
        self.S_BG = np.array(
            [
                1,
                0,
                0,
                0,
                self.measurements["gripper"][2],
                -self.measurements["gripper"][1],
            ]
        )
        # Gripper origin to centroid transform
        self.X_GoGc = self._create_translation_matrix(x=0.07)
        # Gripper origin to tip transform
        self.X_GoGt = self._create_translation_matrix(x=0.12)
        # 0-position gripper frame pose wrt base
        self.X_BoGo = link_offset("gripper")

        # Wrist
        # Screw axis of wrist frame wrt base frame
        self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
        # 0-position origin to centroid transform
        self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
        # 0-position wrist frame pose wrt base
        self.X_BR = link_offset("wrist")

        # Forearm
        # Screw axis of forearm frame wrt base frame
        self.S_BF = np.array(
            [
                0,
                1,
                0,
                -self.measurements["forearm"][2],
                0,
                self.measurements["forearm"][0],
            ]
        )
        # Forearm origin + centroid transform
        self.X_FoFc = self._create_translation_matrix(x=0.036)  # spellchecker:disable-line
        # 0-position forearm frame pose wrt base
        self.X_BF = link_offset("forearm")

        # Humerus
        # Screw axis of humerus frame wrt base frame
        self.S_BH = np.array(
            [
                0,
                -1,
                0,
                self.measurements["humerus"][2],
                0,
                -self.measurements["humerus"][0],
            ]
        )
        # Humerus origin to centroid transform
        self.X_HoHc = self._create_translation_matrix(x=0.0475)
        # 0-position humerus frame pose wrt base
        self.X_BH = link_offset("humerus")

        # Shoulder
        # Screw axis of shoulder frame wrt Base frame
        self.S_BS = np.array([0, 0, -1, 0, 0, 0])
        # Shoulder origin to centroid transform
        self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
        # 0-position shoulder frame pose wrt base
        self.X_BS = link_offset("shoulder")

        # Base
        # Base origin to centroid transform
        self.X_BoBc = self._create_translation_matrix(y=0.015)
        # World to base transform
        self.X_WoBo = link_offset("base")

        # Pre-compute gripper post-multiplication matrix
        self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0

    def _joint_chain(self, robot_pos_deg, num_joints):
        """Return ``X_WoBo`` times the screw transforms of the first ``num_joints`` joints.

        Joints are applied in the order shoulder, humerus, forearm, wrist,
        gripper, multiplying left to right.
        """
        robot_pos_rad = np.asarray(robot_pos_deg) / 180 * np.pi
        screw_axes = (self.S_BS, self.S_BH, self.S_BF, self.S_BR, self.S_BG)
        pose = self.X_WoBo
        for joint_ix, axis in enumerate(screw_axes[:num_joints]):
            pose = pose @ screw_axis_to_transform(axis, robot_pos_rad[joint_ix])
        return pose

    def fk_base(self):
        """Forward kinematics for the base frame (independent of joint positions)."""
        return self.X_WoBo @ self.X_BoBc @ self.base_X0

    def fk_shoulder(self, robot_pos_deg):
        """Forward kinematics for the shoulder frame (uses joint 0, in degrees)."""
        return self._joint_chain(robot_pos_deg, 1) @ self.X_SoSc @ self.X_BS

    def fk_humerus(self, robot_pos_deg):
        """Forward kinematics for the humerus frame (uses joints 0-1, in degrees)."""
        return self._joint_chain(robot_pos_deg, 2) @ self.X_HoHc @ self.X_BH

    def fk_forearm(self, robot_pos_deg):
        """Forward kinematics for the forearm frame (uses joints 0-2, in degrees)."""
        return (
            self._joint_chain(robot_pos_deg, 3)
            @ self.X_FoFc  # spellchecker:disable-line
            @ self.X_BF
        )

    def fk_wrist(self, robot_pos_deg):
        """Forward kinematics for the wrist frame (uses joints 0-3, in degrees)."""
        return self._joint_chain(robot_pos_deg, 4) @ self.X_RoRc @ self.X_BR @ self.wrist_X0

    def fk_gripper(self, robot_pos_deg):
        """Forward kinematics for the gripper frame (uses joints 0-4, in degrees)."""
        return self._joint_chain(robot_pos_deg, 5) @ self._fk_gripper_post

    def fk_gripper_tip(self, robot_pos_deg):
        """Forward kinematics for the gripper tip frame (uses joints 0-4, in degrees)."""
        return self._joint_chain(robot_pos_deg, 5) @ self.X_GoGt @ self.X_BoGo @ self.gripper_X0

    def _finite_difference_jacobian(self, robot_pos_deg, fk_func, pose_delta, num_rows):
        """Central-difference Jacobian shared by the full and positional variants.

        Args:
            robot_pos_deg: Joint positions in degrees; the last entry is excluded.
            fk_func: Forward kinematics function, or None for ``fk_gripper``.
            pose_delta: Callable ``(pose_plus, pose_minus) -> vector`` with
                ``num_rows`` components.
            num_rows: Number of rows of the returned Jacobian.

        Returns:
            A ``(num_rows, 5)`` numpy array of derivatives per degree of joint motion.
        """
        if fk_func is None:
            fk_func = self.fk_gripper

        eps = 1e-8
        joints = robot_pos_deg[:-1]
        jac = np.zeros(shape=(num_rows, 5))
        delta = np.zeros(len(joints), dtype=np.float64)
        for el_ix in range(len(joints)):
            # Perturb a single joint by +/- eps/2 and difference the poses.
            delta[:] = 0
            delta[el_ix] = eps / 2
            jac[:, el_ix] = pose_delta(fk_func(joints + delta), fk_func(joints - delta)) / eps
        return jac

    def compute_jacobian(self, robot_pos_deg, fk_func=None):
        """Finite differences to compute the 6x5 pose Jacobian.

        Column j holds the change of the end-effector pose (translation followed
        by rotation vector, as returned by ``pose_difference_se3``) per unit
        change of joint j, with joints expressed in degrees.

        Args:
            robot_pos_deg: Current joint positions in degrees
            fk_func: Forward kinematics function to use (defaults to fk_gripper)
        """
        return self._finite_difference_jacobian(robot_pos_deg, fk_func, pose_difference_se3, 6)

    def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
        """Finite differences to compute the 3x5 positional Jacobian.

        Column j holds the change of the end-effector position per unit change
        of joint j, with joints expressed in degrees.

        Args:
            robot_pos_deg: Current joint positions in degrees
            fk_func: Forward kinematics function to use (defaults to fk_gripper)
        """

        def position_delta(pose_plus, pose_minus):
            return pose_plus[:3, 3] - pose_minus[:3, 3]

        return self._finite_difference_jacobian(robot_pos_deg, fk_func, position_delta, 3)

    def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
        """Inverse kinematics via iterative Jacobian pseudo-inverse updates.

        Runs at most five update steps and stops early once the error measured
        before the latest update is below 5e-3.

        Args:
            current_joint_state: Initial joint positions in degrees. The input is
                not modified; integer-typed inputs are accepted.
            desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
            position_only: If True, only match end-effector position, not orientation
            fk_func: Forward kinematics function to use (defaults to fk_gripper)

        Returns:
            A new float64 numpy array with the joint positions in degrees. The
            last entry is copied from the input unchanged.
        """
        if fk_func is None:
            fk_func = self.fk_gripper

        # Work on a float copy: the in-place update below must neither leak into
        # the caller's array nor fail on integer-typed joint vectors.
        joint_state = np.array(current_joint_state, dtype=np.float64)

        max_iterations = 5
        learning_rate = 1
        for _ in range(max_iterations):
            current_ee_pose = fk_func(joint_state)
            if position_only:
                error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
                jac = self.compute_positional_jacobian(joint_state, fk_func)
            else:
                error = se3_error(desired_ee_pose, current_ee_pose)
                jac = self.compute_jacobian(joint_state, fk_func)
            delta_angles = np.linalg.pinv(jac) @ error
            joint_state[:-1] += learning_rate * delta_angles

            if np.linalg.norm(error) < 5e-3:
                return joint_state
        return joint_state
if __name__ == "__main__":
    import time

    def run_test(robot_type):
        """Run the self-checks for one robot type and return whether all passed.

        Three checks are printed: that frame origins get progressively further
        from the world origin along the chain, that both Jacobians have the
        expected shapes, and that position-only IK recovers a pose generated by
        forward kinematics to within 1e-2.
        """

        def status(ok):
            return "PASSED" if ok else "FAILED"

        print(f"\n--- Testing {robot_type.upper()} Robot ---")

        # Initialize kinematics for this robot
        robot = RobotKinematics(robot_type)

        # Test 1: Forward kinematics consistency
        print("Test 1: Forward kinematics consistency")
        test_angles = np.array([30, 45, -30, 20, 10, 0])  # Example joint angles in degrees

        # FK of every frame, ordered from the shoulder out to the gripper tip
        chain_poses = [
            robot.fk_shoulder(test_angles),
            robot.fk_humerus(test_angles),
            robot.fk_forearm(test_angles),
            robot.fk_wrist(test_angles),
            robot.fk_gripper(test_angles),
            robot.fk_gripper_tip(test_angles),
        ]
        distances = [np.linalg.norm(pose[:3, 3]) for pose in chain_poses]

        # Distances from the origin should never decrease along the chain
        is_consistent = all(near <= far for near, far in zip(distances, distances[1:]))
        print(f"  Pose distances from origin: {[round(d, 3) for d in distances]}")
        print(f"  Kinematic chain consistency: {status(is_consistent)}")

        # Test 2: Jacobian computation
        print("Test 2: Jacobian computation")
        jacobian = robot.compute_jacobian(test_angles)
        positional_jacobian = robot.compute_positional_jacobian(test_angles)

        # Check shapes
        jacobian_shape_ok = jacobian.shape == (6, 5)
        pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
        print(f"  Jacobian shape: {status(jacobian_shape_ok)}")
        print(f"  Positional Jacobian shape: {status(pos_jacobian_shape_ok)}")

        # Test 3: Inverse kinematics
        print("Test 3: Inverse kinematics (position only)")

        # Target pose generated from known joint angles
        original_angles = np.array([10, 20, 30, -10, 5, 0])
        target_pose = robot.fk_gripper(original_angles)

        # Start IK from a different position
        initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

        # Time the IK call
        start_time = time.time()
        computed_angles = robot.ik(initial_guess.copy(), target_pose)
        ik_time = time.time() - start_time

        # Position error of the pose reached by the IK solution
        result_pose = robot.fk_gripper(computed_angles)
        pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
        passed = pos_error < 0.01  # Accept errors less than 1cm

        print(f"  IK computation time: {ik_time:.4f} seconds")
        print(f"  Position error: {pos_error:.4f}")
        print(f"  IK position accuracy: {status(passed)}")

        return all((is_consistent, jacobian_shape_ok, pos_jacobian_shape_ok, passed))

    # Run tests for all robot types
    results = {robot_type: run_test(robot_type) for robot_type in ["koch", "so100", "moss"]}

    # Print overall summary
    print("\n=== Test Summary ===")
    all_passed = all(results.values())
    for robot_type, passed in results.items():
        print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
    print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")

View File

@@ -5,7 +5,7 @@ import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from tests.fixtures.constants import DUMMY_REPO_ID