[Port HIl-Serl] Refactor gym-manipulator (#1034)

This commit is contained in:
Michel Aractingi
2025-04-25 16:34:54 +02:00
committed by GitHub
parent a8da4a347e
commit bd4db8d747
13 changed files with 624 additions and 946 deletions

View File

@@ -182,15 +182,15 @@ class EEActionSpaceConfig:
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
use_gamepad: bool = False
control_mode: str = "gamepad"
@dataclass
class EnvWrapperConfig:
"""Configuration for environment wrappers."""
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
display_cameras: bool = False
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
@@ -199,13 +199,10 @@ class EnvWrapperConfig:
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")

View File

@@ -308,13 +308,13 @@ def reset_environment(robot, events, reset_time_s, fps):
)
def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
def reset_follower_position(robot_arm, target_position):
current_position = robot_arm.read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 30 is just an arbitrary number
for pose in trajectory:
robot.send_action(pose)
robot_arm.write("Goal_Position", pose)
busy_wait(0.015)

View File

@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760433331",
port="/dev/tty.usbmodem58760431091",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431631",
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],

View File

@@ -167,7 +167,7 @@ from lerobot.common.robot_devices.control_utils import (
warmup_record,
)
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import safe_disconnect
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.configs import parser
@@ -276,6 +276,7 @@ def record(
if not robot.is_connected:
robot.connect()
listener, events = init_keyboard_listener()
# Execute a few seconds without recording to:
@@ -284,14 +285,7 @@ def record(
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", cfg.play_sounds)
warmup_record(
robot,
events,
enable_teleoperation,
cfg.warmup_time_s,
cfg.display_data,
cfg.fps,
)
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
@@ -356,6 +350,7 @@ def replay(
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
@@ -366,6 +361,9 @@ def replay(
action = actions[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / cfg.fps - dt_s)
dt_s = time.perf_counter() - start_episode_t
log_control_info(robot, dt_s, fps=cfg.fps)

View File

@@ -20,13 +20,11 @@ from functools import lru_cache
from queue import Empty
from statistics import mean, quantiles
# from lerobot.scripts.eval import eval_policy
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.utils import busy_wait
@@ -39,20 +37,21 @@ from lerobot.common.utils.utils import (
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import (
Transition,
bytes_to_state_dict,
move_state_dict_to_device,
move_transition_to_device,
python_object_to_bytes,
transitions_to_bytes,
)
from lerobot.scripts.server.buffer import Transition
from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.scripts.server.network_utils import (
bytes_to_state_dict,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.scripts.server.utils import (
get_last_item_from_queue,
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
ACTOR_SHUTDOWN_TIMEOUT = 30
@@ -134,21 +133,8 @@ def actor_cli(cfg: TrainPipelineConfig):
interactions_process.start()
receive_policy_process.start()
# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
reward_classifier = None
# if (
# cfg.env.reward_classifier["pretrained_path"] is not None
# and cfg.env.reward_classifier["config_path"] is not None
# ):
# reward_classifier = get_classifier(
# pretrained_path=cfg.env.reward_classifier["pretrained_path"],
# config_path=cfg.env.reward_classifier["config_path"],
# )
act_with_policy(
cfg=cfg,
reward_classifier=reward_classifier,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
transitions_queue=transitions_queue,
@@ -183,7 +169,6 @@ def actor_cli(cfg: TrainPipelineConfig):
def act_with_policy(
cfg: TrainPipelineConfig,
reward_classifier: nn.Module,
shutdown_event: any, # Event,
parameters_queue: Queue,
transitions_queue: Queue,
@@ -197,7 +182,6 @@ def act_with_policy(
Args:
cfg: Configuration settings for the interaction process.
reward_classifier: Reward classifier to use for the interaction process.
shutdown_event: Event to check if the process should shutdown.
parameters_queue: Queue to receive updated network parameters from the learner.
transitions_queue: Queue to send transitions to the learner.
@@ -262,16 +246,10 @@ def act_with_policy(
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
)
next_obs, reward, done, truncated, info = online_env.step(action)
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
@@ -286,11 +264,6 @@ def act_with_policy(
# Increment intervention steps counter
episode_intervention_steps += 1
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
list_transition_to_send_to_learner.append(
Transition(
state=obs,

View File

@@ -15,26 +15,15 @@
# limitations under the License.
import functools
import io
import pickle # nosec B403: Safe usage of pickle
from contextlib import suppress
from typing import Any, Callable, Optional, Sequence, TypedDict
from typing import Callable, Optional, Sequence, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None
from lerobot.scripts.server.utils import Transition
class BatchTransition(TypedDict):
@@ -47,103 +36,6 @@ class BatchTransition(TypedDict):
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
device = torch.device(device)
non_blocking = device.type == "cuda"
# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}
# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition
def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
"""
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer) # nosec B614: Safe usage of torch.load
def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
@@ -514,7 +406,6 @@ class ReplayBuffer:
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
@@ -566,13 +457,6 @@ class ReplayBuffer:
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)
# Apply action mask/delta if needed
if action_mask is not None:
if first_action.dim() == 1:
first_action = first_action[action_mask]
else:
first_action = first_action[:, action_mask]
# Get complementary info if available
first_complementary_info = None
if (
@@ -597,8 +481,6 @@ class ReplayBuffer:
data[k] = v.to(storage_device)
action = data["action"]
if action_mask is not None:
action = action[action_mask] if action.dim() == 1 else action[:, action_mask]
replay_buffer.add(
state=data["state"],

View File

@@ -524,7 +524,9 @@ def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = leader_ee
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -544,6 +546,8 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = np.diag(np.ones(4))
joint_positions = robot.follower_arms["main"].read("Present_Position")
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
@@ -561,25 +565,26 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
# Calculate delta between leader and follower end-effectors
# Scaling factor can be adjusted for sensitivity
scaling_factor = 1.0
ee_delta = (leader_ee - initial_leader_ee) * scaling_factor
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
# Apply delta to current position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3]
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3]
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3]
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
# Compute joint targets via inverse kinematics
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
# Compute joint targets via inverse kinematics
target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
initial_leader_ee = leader_ee.copy()
initial_leader_ee = leader_ee.copy()
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
# Logging
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
# Logging
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -715,8 +720,8 @@ if __name__ == "__main__":
"gamepad",
"keyboard_gym",
"gamepad_gym",
"leader_delta",
"leader",
"leader_abs",
],
help="Control mode to use",
)
@@ -768,11 +773,11 @@ if __name__ == "__main__":
env = make_robot_env(cfg, robot)
teleoperate_gym_env(env, controller, fps=cfg.fps)
elif args.mode == "leader":
elif args.mode == "leader_delta":
# Leader-follower modes don't use controllers
teleoperate_delta_inverse_kinematics_with_leader(robot)
elif args.mode == "leader_abs":
elif args.mode == "leader":
teleoperate_inverse_kinematics_with_leader(robot)
finally:

View File

@@ -11,7 +11,7 @@ from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
follower_port = "/dev/tty.usbmodem58760431631"
leader_port = "/dev/tty.usbmodem58760433331"
leader_port = "/dev/tty.usbmodem585A0077921"
def find_joint_bounds(

File diff suppressed because it is too large Load Diff

View File

@@ -23,8 +23,6 @@ from pathlib import Path
from pprint import pformat
import grpc
# Import generated stubs
import hilserl_pb2_grpc # type: ignore
import torch
from termcolor import colored
@@ -39,8 +37,6 @@ from lerobot.common.constants import (
TRAINING_STATE_DIR,
)
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
@@ -62,16 +58,17 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.scripts.server.network_utils import (
bytes_to_python_object,
bytes_to_transitions,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
state_to_bytes,
)
from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.scripts.server.utils import (
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
LOG_PREFIX = "[LEARNER]"
@@ -307,17 +304,10 @@ def add_actor_information_and_train(
offline_replay_buffer = None
if cfg.dataset is not None:
active_action_dims = None
# TODO: FIX THIS
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
device=device,
storage_device=storage_device,
active_action_dims=active_action_dims,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
@@ -342,7 +332,6 @@ def add_actor_information_and_train(
break
# Process all available transitions to the replay buffer, send by the actor server
logging.debug("[LEARNER] Waiting for transitions")
process_transitions(
transition_queue=transition_queue,
replay_buffer=replay_buffer,
@@ -351,35 +340,29 @@ def add_actor_information_and_train(
dataset_repo_id=dataset_repo_id,
shutdown_event=shutdown_event,
)
logging.debug("[LEARNER] Received transitions")
# Process all available interaction messages sent by the actor server
logging.debug("[LEARNER] Waiting for interactions")
interaction_message = process_interaction_messages(
interaction_message_queue=interaction_message_queue,
interaction_step_shift=interaction_step_shift,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
)
logging.debug("[LEARNER] Received interactions")
# Wait until the replay buffer has enough samples to start training
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
logging.debug("[LEARNER] Initializing online replay buffer iterator")
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
# Sample from the iterators
@@ -967,7 +950,6 @@ def initialize_offline_replay_buffer(
cfg: TrainPipelineConfig,
device: str,
storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.
@@ -976,7 +958,6 @@ def initialize_offline_replay_buffer(
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking
Returns:
ReplayBuffer: Initialized offline replay buffer
@@ -997,7 +978,6 @@ def initialize_offline_replay_buffer(
offline_dataset,
device=device,
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity,
@@ -1096,44 +1076,6 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
parameters_queue.put(state_bytes)
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
"""
Checks whether each parameter in the module has a gradient.
Args:
module (nn.Module): A PyTorch module whose parameters will be inspected.
Returns:
dict[str, bool]: A dictionary where each key is the parameter name and the value is
True if the parameter has an associated gradient (i.e. .grad is not None),
otherwise False.
"""
grad_status = {}
for name, param in module.named_parameters():
grad_status[name] = param.grad is not None
return grad_status
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
"""
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
Args:
actor (nn.Module): The actor model.
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
whether each parameter has a gradient.
Returns:
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
"""
# Get actor parameter names as a set.
model_param_names = {name for name, _ in model.named_parameters()}
# Intersect parameter names between actor and grad_status.
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
return overlapping
def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):

View File

@@ -1,221 +0,0 @@
from typing import Any
import einops
import gymnasium as gym
import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from lerobot.common.envs.configs import ManiskillEnvConfig
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
# map to expected inputs for the policy
return_observations = {}
# TODO: You have to merge all tensors from agent key and extra key
# You don't keep sensor param key in the observation
# And you keep sensor data rgb
q_pos = observations["agent"]["qpos"]
q_vel = observations["agent"]["qvel"]
tcp_pos = observations["extra"]["tcp_pose"]
img = observations["sensor_data"]["base_camera"]["rgb"]
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img
return_observations["observation.state"] = state
return return_observations
class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env)
if isinstance(device, str):
device = torch.device(device)
self.device = device
def observation(self, observation):
observation = preprocess_maniskill_observation(observation)
observation = {k: v.to(self.device) for k, v in observation.items()}
return observation
class ManiSkillCompat(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Any, dict[str, Any]]:
options = {}
return super().reset(seed=seed, options=options)
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
reward = reward.item()
terminated = terminated.item()
truncated = truncated.item()
return obs, reward, terminated, truncated, info
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
def action(self, action):
action, telop = action
return action
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
def __init__(self, env, multiply_factor: float = 1):
super().__init__(env)
self.multiply_factor = multiply_factor
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
def step(self, action):
if isinstance(action, tuple):
action, telop = action
else:
telop = 0
action = action / self.multiply_factor
obs, reward, terminated, truncated, info = self.env.step((action, telop))
return obs, reward, terminated, truncated, info
class BatchCompatibleWrapper(gym.ObservationWrapper):
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
def __init__(self, env):
super().__init__(env)
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
if "state" in key and observation[key].dim() == 1:
observation[key] = observation[key].unsqueeze(0)
return observation
class TimeLimitWrapper(gym.Wrapper):
"""Adds a time limit to the environment based on fps and control_time."""
def __init__(self, env, control_time_s, fps):
super().__init__(env)
self.control_time_s = control_time_s
self.fps = fps
self.max_episode_steps = int(self.control_time_s * self.fps)
self.current_step = 0
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
self.current_step += 1
if self.current_step >= self.max_episode_steps:
terminated = True
return obs, reward, terminated, truncated, info
def reset(self, *, seed=None, options=None):
self.current_step = 0
return super().reset(seed=seed, options=options)
class ManiskillMockGripperWrapper(gym.Wrapper):
def __init__(self, env, nb_discrete_actions: int = 3):
super().__init__(env)
new_shape = env.action_space[0].shape[0] + 1
new_low = np.concatenate([env.action_space[0].low, [0]])
new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
def step(self, action):
if isinstance(action, tuple):
action_agent, telop_action = action
else:
telop_action = 0
action_agent = action
real_action = action_agent[:-1]
final_action = (real_action, telop_action)
obs, reward, terminated, truncated, info = self.env.step(final_action)
return obs, reward, terminated, truncated, info
def make_maniskill(
cfg: ManiskillEnvConfig,
n_envs: int | None = None,
) -> gym.Env:
"""
Factory function to create a ManiSkill environment with standard wrappers.
Args:
cfg: Configuration for the ManiSkill environment
n_envs: Number of parallel environments
Returns:
A wrapped ManiSkill environment
"""
env = gym.make(
cfg.task,
obs_mode=cfg.obs_type,
control_mode=cfg.control_mode,
render_mode=cfg.render_mode,
sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
num_envs=n_envs,
)
# Add video recording if enabled
if cfg.video_record.enabled:
env = RecordEpisode(
env,
output_dir=cfg.video_record.record_dir,
save_trajectory=True,
trajectory_name=cfg.video_record.trajectory_name,
save_video=True,
video_fps=30,
)
# Add observation and image processing
env = ManiSkillObservationWrapper(env, device=cfg.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
env.unwrapped.metadata["render_fps"] = cfg.fps
# Add compatibility wrappers
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
if cfg.mock_gripper:
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
return env

View File

@@ -17,10 +17,14 @@
import io
import logging
import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event, Queue
from typing import Any
import torch
from lerobot.scripts.server import hilserl_pb2
from lerobot.scripts.server.utils import Transition
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
@@ -60,7 +64,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""):
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
bytes_buffer = io.BytesIO()
step = 0
@@ -93,3 +97,44 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
step = 0
logging.debug(f"{log_prefix} Queue updated")
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer) # nosec B614: Safe usage of torch.load
def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()

View File

@@ -19,7 +19,9 @@ import logging
import signal
import sys
from queue import Empty
from typing import TypedDict
import torch
from torch.multiprocessing import Queue
shutdown_event_counter = 0
@@ -71,3 +73,69 @@ def get_last_item_from_queue(queue: Queue):
logging.debug(f"Drained {counter} items from queue")
return item
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
device = torch.device(device)
non_blocking = device.type == "cuda"
# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}
# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition
def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
"""
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict