[Port HIL-SERL] Refactor gym-manipulator (#1034)

This commit is contained in:
Michel Aractingi
2025-04-25 16:34:54 +02:00
committed by GitHub
parent a8da4a347e
commit bd4db8d747
13 changed files with 624 additions and 946 deletions

View File

@@ -182,15 +182,15 @@ class EEActionSpaceConfig:
y_step_size: float y_step_size: float
z_step_size: float z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
use_gamepad: bool = False control_mode: str = "gamepad"
@dataclass @dataclass
class EnvWrapperConfig: class EnvWrapperConfig:
"""Configuration for environment wrappers.""" """Configuration for environment wrappers."""
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
display_cameras: bool = False display_cameras: bool = False
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False add_ee_pose_to_observation: bool = False
@@ -199,13 +199,10 @@ class EnvWrapperConfig:
control_time_s: float = 20.0 control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0 reset_time_s: float = 5.0
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8 gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0 gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False
@EnvConfig.register_subclass(name="gym_manipulator") @EnvConfig.register_subclass(name="gym_manipulator")

View File

@@ -308,13 +308,13 @@ def reset_environment(robot, events, reset_time_s, fps):
) )
def reset_follower_position(robot: Robot, target_position): def reset_follower_position(robot_arm, target_position):
current_position = robot.follower_arms["main"].read("Present_Position") current_position = robot_arm.read("Present_Position")
trajectory = torch.from_numpy( trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50) np.linspace(current_position, target_position, 50)
) # NOTE: 50 is just an arbitrary number ) # NOTE: 50 is just an arbitrary number
for pose in trajectory: for pose in trajectory:
robot.send_action(pose) robot_arm.write("Goal_Position", pose)
busy_wait(0.015) busy_wait(0.015)

View File

@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field( leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: { default_factory=lambda: {
"main": FeetechMotorsBusConfig( "main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760433331", port="/dev/tty.usbmodem58760431091",
motors={ motors={
# name: (index, model) # name: (index, model)
"shoulder_pan": [1, "sts3215"], "shoulder_pan": [1, "sts3215"],
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field( follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: { default_factory=lambda: {
"main": FeetechMotorsBusConfig( "main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431631", port="/dev/tty.usbmodem585A0076891",
motors={ motors={
# name: (index, model) # name: (index, model)
"shoulder_pan": [1, "sts3215"], "shoulder_pan": [1, "sts3215"],

View File

@@ -167,7 +167,7 @@ from lerobot.common.robot_devices.control_utils import (
warmup_record, warmup_record,
) )
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import safe_disconnect from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.configs import parser from lerobot.configs import parser
@@ -276,6 +276,7 @@ def record(
if not robot.is_connected: if not robot.is_connected:
robot.connect() robot.connect()
listener, events = init_keyboard_listener() listener, events = init_keyboard_listener()
# Execute a few seconds without recording to: # Execute a few seconds without recording to:
@@ -284,14 +285,7 @@ def record(
# 3. place the cameras windows on screen # 3. place the cameras windows on screen
enable_teleoperation = policy is None enable_teleoperation = policy is None
log_say("Warmup record", cfg.play_sounds) log_say("Warmup record", cfg.play_sounds)
warmup_record( warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
robot,
events,
enable_teleoperation,
cfg.warmup_time_s,
cfg.display_data,
cfg.fps,
)
if has_method(robot, "teleop_safety_stop"): if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop() robot.teleop_safety_stop()
@@ -356,6 +350,7 @@ def replay(
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
actions = dataset.hf_dataset.select_columns("action") actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected: if not robot.is_connected:
robot.connect() robot.connect()
@@ -366,6 +361,9 @@ def replay(
action = actions[idx]["action"] action = actions[idx]["action"]
robot.send_action(action) robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / cfg.fps - dt_s)
dt_s = time.perf_counter() - start_episode_t dt_s = time.perf_counter() - start_episode_t
log_control_info(robot, dt_s, fps=cfg.fps) log_control_info(robot, dt_s, fps=cfg.fps)

View File

@@ -20,13 +20,11 @@ from functools import lru_cache
from queue import Empty from queue import Empty
from statistics import mean, quantiles from statistics import mean, quantiles
# from lerobot.scripts.eval import eval_policy
import grpc import grpc
import torch import torch
from torch import nn from torch import nn
from torch.multiprocessing import Event, Queue from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.robot_devices.utils import busy_wait
@@ -39,20 +37,21 @@ from lerobot.common.utils.utils import (
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import ( from lerobot.scripts.server.buffer import Transition
Transition,
bytes_to_state_dict,
move_state_dict_to_device,
move_transition_to_device,
python_object_to_bytes,
transitions_to_bytes,
)
from lerobot.scripts.server.gym_manipulator import make_robot_env from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.scripts.server.network_utils import ( from lerobot.scripts.server.network_utils import (
bytes_to_state_dict,
python_object_to_bytes,
receive_bytes_in_chunks, receive_bytes_in_chunks,
send_bytes_in_chunks, send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.scripts.server.utils import (
get_last_item_from_queue,
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
) )
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
ACTOR_SHUTDOWN_TIMEOUT = 30 ACTOR_SHUTDOWN_TIMEOUT = 30
@@ -134,21 +133,8 @@ def actor_cli(cfg: TrainPipelineConfig):
interactions_process.start() interactions_process.start()
receive_policy_process.start() receive_policy_process.start()
# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
reward_classifier = None
# if (
# cfg.env.reward_classifier["pretrained_path"] is not None
# and cfg.env.reward_classifier["config_path"] is not None
# ):
# reward_classifier = get_classifier(
# pretrained_path=cfg.env.reward_classifier["pretrained_path"],
# config_path=cfg.env.reward_classifier["config_path"],
# )
act_with_policy( act_with_policy(
cfg=cfg, cfg=cfg,
reward_classifier=reward_classifier,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
parameters_queue=parameters_queue, parameters_queue=parameters_queue,
transitions_queue=transitions_queue, transitions_queue=transitions_queue,
@@ -183,7 +169,6 @@ def actor_cli(cfg: TrainPipelineConfig):
def act_with_policy( def act_with_policy(
cfg: TrainPipelineConfig, cfg: TrainPipelineConfig,
reward_classifier: nn.Module,
shutdown_event: any, # Event, shutdown_event: any, # Event,
parameters_queue: Queue, parameters_queue: Queue,
transitions_queue: Queue, transitions_queue: Queue,
@@ -197,7 +182,6 @@ def act_with_policy(
Args: Args:
cfg: Configuration settings for the interaction process. cfg: Configuration settings for the interaction process.
reward_classifier: Reward classifier to use for the interaction process.
shutdown_event: Event to check if the process should shutdown. shutdown_event: Event to check if the process should shutdown.
parameters_queue: Queue to receive updated network parameters from the learner. parameters_queue: Queue to receive updated network parameters from the learner.
transitions_queue: Queue to send transitions to the learner. transitions_queue: Queue to send transitions to the learner.
@@ -262,16 +246,10 @@ def act_with_policy(
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
else: else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample() action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box next_obs, reward, done, truncated, info = online_env.step(action)
action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
)
sum_reward_episode += float(reward) sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate # Increment total steps counter for intervention rate
@@ -286,11 +264,6 @@ def act_with_policy(
# Increment intervention steps counter # Increment intervention steps counter
episode_intervention_steps += 1 episode_intervention_steps += 1
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
list_transition_to_send_to_learner.append( list_transition_to_send_to_learner.append(
Transition( Transition(
state=obs, state=obs,

View File

@@ -15,26 +15,15 @@
# limitations under the License. # limitations under the License.
import functools import functools
import io
import pickle # nosec B403: Safe usage of pickle
from contextlib import suppress from contextlib import suppress
from typing import Any, Callable, Optional, Sequence, TypedDict from typing import Callable, Optional, Sequence, TypedDict
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from tqdm import tqdm from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.utils import Transition
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None
class BatchTransition(TypedDict): class BatchTransition(TypedDict):
@@ -47,103 +36,6 @@ class BatchTransition(TypedDict):
complementary_info: dict[str, torch.Tensor | float | int] | None = None complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
device = torch.device(device)
non_blocking = device.type == "cuda"
# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}
# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition
def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
"""
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer) # nosec B614: Safe usage of torch.load
def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
""" """
Perform a per-image random crop over a batch of images in a vectorized way. Perform a per-image random crop over a batch of images in a vectorized way.
@@ -514,7 +406,6 @@ class ReplayBuffer:
device: str = "cuda:0", device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None, state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None, capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
image_augmentation_function: Optional[Callable] = None, image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True, use_drq: bool = True,
storage_device: str = "cpu", storage_device: str = "cpu",
@@ -566,13 +457,6 @@ class ReplayBuffer:
first_state = {k: v.to(device) for k, v in first_transition["state"].items()} first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device) first_action = first_transition["action"].to(device)
# Apply action mask/delta if needed
if action_mask is not None:
if first_action.dim() == 1:
first_action = first_action[action_mask]
else:
first_action = first_action[:, action_mask]
# Get complementary info if available # Get complementary info if available
first_complementary_info = None first_complementary_info = None
if ( if (
@@ -597,8 +481,6 @@ class ReplayBuffer:
data[k] = v.to(storage_device) data[k] = v.to(storage_device)
action = data["action"] action = data["action"]
if action_mask is not None:
action = action[action_mask] if action.dim() == 1 else action[:, action_mask]
replay_buffer.add( replay_buffer.add(
state=data["state"], state=data["state"],

View File

@@ -524,7 +524,9 @@ def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions) leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = leader_ee desired_ee_pos = leader_ee
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True) target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
robot.send_action(torch.from_numpy(target_joint_state)) robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}") logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -544,6 +546,8 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions) initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = np.diag(np.ones(4)) desired_ee_pos = np.diag(np.ones(4))
joint_positions = robot.follower_arms["main"].read("Present_Position")
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
while time.perf_counter() - timestep < 60.0: while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter() loop_start_time = time.perf_counter()
@@ -561,25 +565,26 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
# Calculate delta between leader and follower end-effectors # Calculate delta between leader and follower end-effectors
# Scaling factor can be adjusted for sensitivity # Scaling factor can be adjusted for sensitivity
scaling_factor = 1.0 scaling_factor = 1.0
ee_delta = (leader_ee - initial_leader_ee) * scaling_factor ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
# Apply delta to current position # Apply delta to current position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3] desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3] desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3] desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
if np.any(np.abs(ee_delta[:3, 3]) > 0.01): # Compute joint targets via inverse kinematics
# Compute joint targets via inverse kinematics target_joint_state = kinematics.ik(
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True) joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
initial_leader_ee = leader_ee.copy() initial_leader_ee = leader_ee.copy()
# Send command to robot # Send command to robot
robot.send_action(torch.from_numpy(target_joint_state)) robot.send_action(torch.from_numpy(target_joint_state))
# Logging # Logging
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}") logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}") logging.info(f"Delta EE: {ee_delta[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -715,8 +720,8 @@ if __name__ == "__main__":
"gamepad", "gamepad",
"keyboard_gym", "keyboard_gym",
"gamepad_gym", "gamepad_gym",
"leader_delta",
"leader", "leader",
"leader_abs",
], ],
help="Control mode to use", help="Control mode to use",
) )
@@ -768,11 +773,11 @@ if __name__ == "__main__":
env = make_robot_env(cfg, robot) env = make_robot_env(cfg, robot)
teleoperate_gym_env(env, controller, fps=cfg.fps) teleoperate_gym_env(env, controller, fps=cfg.fps)
elif args.mode == "leader": elif args.mode == "leader_delta":
# Leader-follower modes don't use controllers # Leader-follower modes don't use controllers
teleoperate_delta_inverse_kinematics_with_leader(robot) teleoperate_delta_inverse_kinematics_with_leader(robot)
elif args.mode == "leader_abs": elif args.mode == "leader":
teleoperate_inverse_kinematics_with_leader(robot) teleoperate_inverse_kinematics_with_leader(robot)
finally: finally:

View File

@@ -11,7 +11,7 @@ from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics from lerobot.scripts.server.kinematics import RobotKinematics
follower_port = "/dev/tty.usbmodem58760431631" follower_port = "/dev/tty.usbmodem58760431631"
leader_port = "/dev/tty.usbmodem58760433331" leader_port = "/dev/tty.usbmodem585A0077921"
def find_joint_bounds( def find_joint_bounds(

File diff suppressed because it is too large Load Diff

View File

@@ -23,8 +23,6 @@ from pathlib import Path
from pprint import pformat from pprint import pformat
import grpc import grpc
# Import generated stubs
import hilserl_pb2_grpc # type: ignore import hilserl_pb2_grpc # type: ignore
import torch import torch
from termcolor import colored from termcolor import colored
@@ -39,8 +37,6 @@ from lerobot.common.constants import (
TRAINING_STATE_DIR, TRAINING_STATE_DIR,
) )
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
@@ -62,16 +58,17 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import ( from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions
ReplayBuffer, from lerobot.scripts.server.network_utils import (
bytes_to_python_object, bytes_to_python_object,
bytes_to_transitions, bytes_to_transitions,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
state_to_bytes, state_to_bytes,
) )
from lerobot.scripts.server.utils import setup_process_handlers from lerobot.scripts.server.utils import (
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
LOG_PREFIX = "[LEARNER]" LOG_PREFIX = "[LEARNER]"
@@ -307,17 +304,10 @@ def add_actor_information_and_train(
offline_replay_buffer = None offline_replay_buffer = None
if cfg.dataset is not None: if cfg.dataset is not None:
active_action_dims = None
# TODO: FIX THIS
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer( offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg, cfg=cfg,
device=device, device=device,
storage_device=storage_device, storage_device=storage_device,
active_action_dims=active_action_dims,
) )
batch_size: int = batch_size // 2 # We will sample from both replay buffer batch_size: int = batch_size // 2 # We will sample from both replay buffer
@@ -342,7 +332,6 @@ def add_actor_information_and_train(
break break
# Process all available transitions to the replay buffer, send by the actor server # Process all available transitions to the replay buffer, send by the actor server
logging.debug("[LEARNER] Waiting for transitions")
process_transitions( process_transitions(
transition_queue=transition_queue, transition_queue=transition_queue,
replay_buffer=replay_buffer, replay_buffer=replay_buffer,
@@ -351,35 +340,29 @@ def add_actor_information_and_train(
dataset_repo_id=dataset_repo_id, dataset_repo_id=dataset_repo_id,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
) )
logging.debug("[LEARNER] Received transitions")
# Process all available interaction messages sent by the actor server # Process all available interaction messages sent by the actor server
logging.debug("[LEARNER] Waiting for interactions")
interaction_message = process_interaction_messages( interaction_message = process_interaction_messages(
interaction_message_queue=interaction_message_queue, interaction_message_queue=interaction_message_queue,
interaction_step_shift=interaction_step_shift, interaction_step_shift=interaction_step_shift,
wandb_logger=wandb_logger, wandb_logger=wandb_logger,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
) )
logging.debug("[LEARNER] Received interactions")
# Wait until the replay buffer has enough samples to start training # Wait until the replay buffer has enough samples to start training
if len(replay_buffer) < online_step_before_learning: if len(replay_buffer) < online_step_before_learning:
continue continue
if online_iterator is None: if online_iterator is None:
logging.debug("[LEARNER] Initializing online replay buffer iterator")
online_iterator = replay_buffer.get_iterator( online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
) )
if offline_replay_buffer is not None and offline_iterator is None: if offline_replay_buffer is not None and offline_iterator is None:
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
offline_iterator = offline_replay_buffer.get_iterator( offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
) )
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time() time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1): for _ in range(utd_ratio - 1):
# Sample from the iterators # Sample from the iterators
@@ -967,7 +950,6 @@ def initialize_offline_replay_buffer(
cfg: TrainPipelineConfig, cfg: TrainPipelineConfig,
device: str, device: str,
storage_device: str, storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer: ) -> ReplayBuffer:
""" """
Initialize an offline replay buffer from a dataset. Initialize an offline replay buffer from a dataset.
@@ -976,7 +958,6 @@ def initialize_offline_replay_buffer(
cfg (TrainPipelineConfig): Training configuration cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on device (str): Device to store tensors on
storage_device (str): Device for storage optimization storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking
Returns: Returns:
ReplayBuffer: Initialized offline replay buffer ReplayBuffer: Initialized offline replay buffer
@@ -997,7 +978,6 @@ def initialize_offline_replay_buffer(
offline_dataset, offline_dataset,
device=device, device=device,
state_keys=cfg.policy.input_features.keys(), state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
storage_device=storage_device, storage_device=storage_device,
optimize_memory=True, optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity, capacity=cfg.policy.offline_buffer_capacity,
@@ -1096,44 +1076,6 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
parameters_queue.put(state_bytes) parameters_queue.put(state_bytes)
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
"""
Checks whether each parameter in the module has a gradient.
Args:
module (nn.Module): A PyTorch module whose parameters will be inspected.
Returns:
dict[str, bool]: A dictionary where each key is the parameter name and the value is
True if the parameter has an associated gradient (i.e. .grad is not None),
otherwise False.
"""
grad_status = {}
for name, param in module.named_parameters():
grad_status[name] = param.grad is not None
return grad_status
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
"""
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
Args:
actor (nn.Module): The actor model.
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
whether each parameter has a gradient.
Returns:
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
"""
# Get actor parameter names as a set.
model_param_names = {name for name, _ in model.named_parameters()}
# Intersect parameter names between actor and grad_status.
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
return overlapping
def process_interaction_message( def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
): ):

View File

@@ -1,221 +0,0 @@
from typing import Any
import einops
import gymnasium as gym
import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from lerobot.common.envs.configs import ManiskillEnvConfig
def preprocess_maniskill_observation(
    observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
    """Convert a ManiSkill environment observation to LeRobot format.

    Args:
        observations: Dictionary of observation batches from a Gym vector environment.
            Expected keys: "agent" (with "qpos" and "qvel"), "extra" (with
            "tcp_pose"), and "sensor_data" (with a channel-last uint8 RGB image
            under "base_camera").

    Returns:
        Dictionary with keys renamed to LeRobot format:
        "observation.image" (float32, channel-first, values in [0, 1]) and
        "observation.state" (qpos, qvel and tcp_pose concatenated on the last dim).
    """
    # map to expected inputs for the policy
    return_observations = {}
    # TODO: You have to merge all tensors from agent key and extra key
    # You don't keep sensor param key in the observation
    # And you keep sensor data rgb
    q_pos = observations["agent"]["qpos"]
    q_vel = observations["agent"]["qvel"]
    tcp_pos = observations["extra"]["tcp_pose"]
    img = observations["sensor_data"]["base_camera"]["rgb"]

    _, h, w, c = img.shape
    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
    # sanity check that images are uint8
    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

    # Convert to channel-first float32 in range [0, 1].
    # (b, h, w, c) -> (b, c, h, w); plain torch permute — equivalent to the
    # previous einops.rearrange call, without the extra dependency.
    img = img.permute(0, 3, 1, 2).contiguous()
    img = img.type(torch.float32)
    img /= 255

    state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)

    return_observations["observation.image"] = img
    return_observations["observation.state"] = state
    return return_observations
class ManiSkillObservationWrapper(gym.ObservationWrapper):
    """Maps raw ManiSkill observations to the LeRobot format and moves them to a device."""

    def __init__(self, env, device: torch.device = "cuda"):
        super().__init__(env)
        # Accept either a string ("cuda", "cpu") or an actual torch.device.
        self.device = torch.device(device) if isinstance(device, str) else device

    def observation(self, observation):
        processed = preprocess_maniskill_observation(observation)
        return {key: tensor.to(self.device) for key, tensor in processed.items()}
class ManiSkillCompat(gym.Wrapper):
    """Squeezes the vector-env batch dimension out of the action space and scalarizes step outputs."""

    def __init__(self, env):
        super().__init__(env)
        act_dim = env.action_space.shape[-1]
        low = np.squeeze(env.action_space.low, axis=0)
        high = np.squeeze(env.action_space.high, axis=0)
        self.action_space = gym.spaces.Box(low=low, high=high, shape=(act_dim,))

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        # NOTE(review): incoming options are deliberately discarded and an
        # empty dict is forwarded, matching the original behavior.
        return super().reset(seed=seed, options={})

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Convert 1-element tensors to plain Python scalars.
        return obs, reward.item(), terminated.item(), truncated.item(), info
class ManiSkillActionWrapper(gym.ActionWrapper):
    """Extends the action space with a discrete teleop flag and strips it before stepping."""

    def __init__(self, env):
        super().__init__(env)
        self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))

    def action(self, action):
        # Only the agent action is forwarded; the teleop flag is dropped.
        agent_action, _teleop_flag = action
        return agent_action
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
    """Scales the continuous action bounds by a factor and undoes the scaling at step time."""

    def __init__(self, env, multiply_factor: float = 1):
        super().__init__(env)
        self.multiply_factor = multiply_factor
        # NOTE(review): the wrapped env's Box bounds are mutated in place here.
        agent_space: gym.spaces.Box = env.action_space[0]
        agent_space.low = agent_space.low * multiply_factor
        agent_space.high = agent_space.high * multiply_factor
        self.action_space = gym.spaces.Tuple(spaces=(agent_space, gym.spaces.Discrete(2)))

    def step(self, action):
        if isinstance(action, tuple):
            action, teleop = action
        else:
            teleop = 0
        # Undo the bound scaling before forwarding to the wrapped env.
        rescaled = action / self.multiply_factor
        return self.env.step((rescaled, teleop))
class BatchCompatibleWrapper(gym.ObservationWrapper):
    """Ensures observations are batch-compatible by adding a batch dimension if necessary."""

    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        for key, value in observation.items():
            # Unbatched images are (C, H, W); unbatched states are (D,).
            needs_batch_dim = ("image" in key and value.dim() == 3) or (
                "state" in key and value.dim() == 1
            )
            if needs_batch_dim:
                observation[key] = value.unsqueeze(0)
        return observation
class TimeLimitWrapper(gym.Wrapper):
    """Adds a time limit to the environment based on fps and control_time."""

    def __init__(self, env, control_time_s, fps):
        super().__init__(env)
        self.control_time_s = control_time_s
        self.fps = fps
        # Total step budget for one episode.
        self.max_episode_steps = int(self.control_time_s * self.fps)
        self.current_step = 0

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        # Force termination once the step budget is exhausted.
        if self.current_step >= self.max_episode_steps:
            terminated = True
        return obs, reward, terminated, truncated, info

    def reset(self, *, seed=None, options=None):
        self.current_step = 0
        return super().reset(seed=seed, options=options)
class ManiskillMockGripperWrapper(gym.Wrapper):
    """Appends a mock discrete gripper dimension to the agent action space.

    The extra trailing dimension is stripped off again before the action is
    forwarded to the wrapped environment.
    """

    def __init__(self, env, nb_discrete_actions: int = 3):
        super().__init__(env)
        agent_box = env.action_space[0]
        low = np.concatenate([agent_box.low, [0]])
        high = np.concatenate([agent_box.high, [nb_discrete_actions - 1]])
        extended_box = gym.spaces.Box(low=low, high=high, shape=(agent_box.shape[0] + 1,))
        self.action_space = gym.spaces.Tuple((extended_box, env.action_space[1]))

    def step(self, action):
        if isinstance(action, tuple):
            agent_action, teleop_action = action
        else:
            agent_action, teleop_action = action, 0
        # Drop the trailing mock gripper dimension before forwarding.
        return self.env.step((agent_action[:-1], teleop_action))
def make_maniskill(
    cfg: ManiskillEnvConfig,
    n_envs: int | None = None,
) -> gym.Env:
    """
    Factory function to create a ManiSkill environment with standard wrappers.

    The wrapper order matters: observation preprocessing is applied before
    vectorization, and the compat/action wrappers are applied last.

    Args:
        cfg: Configuration for the ManiSkill environment
        n_envs: Number of parallel environments

    Returns:
        A wrapped ManiSkill environment
    """
    env = gym.make(
        cfg.task,
        obs_mode=cfg.obs_type,
        control_mode=cfg.control_mode,
        render_mode=cfg.render_mode,
        sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
        num_envs=n_envs,
    )
    # Add video recording if enabled
    if cfg.video_record.enabled:
        env = RecordEpisode(
            env,
            output_dir=cfg.video_record.record_dir,
            save_trajectory=True,
            trajectory_name=cfg.video_record.trajectory_name,
            save_video=True,
            video_fps=30,
        )
    # Add observation and image processing
    env = ManiSkillObservationWrapper(env, device=cfg.device)
    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
    # NOTE(review): both the private and public max-step attributes are set so
    # that code reading either name sees the configured episode length.
    env._max_episode_steps = env.max_episode_steps = cfg.episode_length
    env.unwrapped.metadata["render_fps"] = cfg.fps
    # Add compatibility wrappers
    env = ManiSkillCompat(env)
    env = ManiSkillActionWrapper(env)
    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
    # Optionally append a mock discrete gripper dimension to the action space.
    if cfg.mock_gripper:
        env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
    return env

View File

@@ -17,10 +17,14 @@
import io
import logging
import pickle  # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event, Queue
from typing import Any

import torch

from lerobot.scripts.server import hilserl_pb2
from lerobot.scripts.server.utils import Transition

CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB
@@ -60,7 +64,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
bytes_buffer = io.BytesIO() bytes_buffer = io.BytesIO()
step = 0 step = 0
@@ -93,3 +97,44 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
step = 0 step = 0
logging.debug(f"{log_prefix} Queue updated") logging.debug(f"{log_prefix} Queue updated")
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
    """Serialize a model state dict into raw bytes for transmission."""
    with io.BytesIO() as out:
        torch.save(state_dict, out)
        return out.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
    """Deserialize bytes produced by `state_to_bytes` back into a state dict."""
    stream = io.BytesIO(buffer)
    stream.seek(0)
    return torch.load(stream)  # nosec B614: Safe usage of torch.load
def python_object_to_bytes(python_object: Any) -> bytes:
    """Pickle an arbitrary (trusted, internal) Python object into bytes."""
    return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
    """Unpickle bytes produced by `python_object_to_bytes`.

    Only call this on trusted, internally produced payloads: pickle can
    execute arbitrary code when fed untrusted data.
    """
    stream = io.BytesIO(buffer)
    stream.seek(0)
    # Add validation checks here
    return pickle.load(stream)  # nosec B301: Safe usage of pickle.load
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
    """Deserialize bytes produced by `transitions_to_bytes` into transitions."""
    stream = io.BytesIO(buffer)
    stream.seek(0)
    # Add validation checks here
    return torch.load(stream)  # nosec B614: Safe usage of torch.load
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
    """Serialize a list of transitions into raw bytes for transmission."""
    with io.BytesIO() as out:
        torch.save(transitions, out)
        return out.getvalue()

View File

@@ -19,7 +19,9 @@ import logging
import signal
import sys
from queue import Empty
from typing import TypedDict

import torch
from torch.multiprocessing import Queue

shutdown_event_counter = 0
@@ -71,3 +73,69 @@ def get_last_item_from_queue(queue: Queue):
logging.debug(f"Drained {counter} items from queue") logging.debug(f"Drained {counter} items from queue")
return item return item
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
    """Move every tensor held by a transition to the given device (in place).

    Scalar entries in `complementary_info` are promoted to tensors on the
    target device; transfers use non-blocking copies when targeting CUDA.
    """
    target = torch.device(device)
    non_blocking = target.type == "cuda"

    def _move(value: torch.Tensor) -> torch.Tensor:
        return value.to(target, non_blocking=non_blocking)

    transition["state"] = {key: _move(val) for key, val in transition["state"].items()}
    transition["action"] = _move(transition["action"])

    # reward/done/truncated may be plain Python scalars; only move tensors.
    for field in ("reward", "done", "truncated"):
        if isinstance(transition[field], torch.Tensor):
            transition[field] = _move(transition[field])

    transition["next_state"] = {key: _move(val) for key, val in transition["next_state"].items()}

    extra = transition.get("complementary_info")
    if extra is not None:
        for key, val in extra.items():
            if isinstance(val, torch.Tensor):
                extra[key] = _move(val)
            elif isinstance(val, (int, float, bool)):
                extra[key] = torch.tensor(val, device=target)
            else:
                raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")

    return transition
def move_state_dict_to_device(state_dict, device="cpu"):
    """Recursively move every tensor in a nested dict/list/tuple to `device`.

    Containers are rebuilt with the same structure; non-tensor leaves are
    returned unchanged.
    """
    if isinstance(state_dict, torch.Tensor):
        return state_dict.to(device)
    if isinstance(state_dict, dict):
        return {key: move_state_dict_to_device(val, device=device) for key, val in state_dict.items()}
    if isinstance(state_dict, (list, tuple)):
        moved = [move_state_dict_to_device(item, device=device) for item in state_dict]
        return moved if isinstance(state_dict, list) else tuple(moved)
    return state_dict