[Port HIl-Serl] Refactor gym-manipulator (#1034)
@@ -182,15 +182,15 @@ class EEActionSpaceConfig:
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
use_gamepad: bool = False
control_mode: str = "gamepad"

@dataclass
class EnvWrapperConfig:
"""Configuration for environment wrappers."""

ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
display_cameras: bool = False
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
@@ -199,13 +199,10 @@ class EnvWrapperConfig:
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False
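
For context, a minimal sketch of how the two dataclasses above could be instantiated (the values are illustrative, not defaults from this PR; x_step_size is assumed to exist alongside the y/z fields shown):

ee_params = EEActionSpaceConfig(
    x_step_size=0.01,  # assumed field, by analogy with y_step_size / z_step_size
    y_step_size=0.01,
    z_step_size=0.01,
    bounds={"min": [-0.3, -0.3, 0.0], "max": [0.3, 0.3, 0.4]},
    control_mode="gamepad",
)
wrapper_cfg = EnvWrapperConfig(
    ee_action_space_params=ee_params,
    use_gripper=True,
    control_time_s=20.0,
)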

@EnvConfig.register_subclass(name="gym_manipulator")

@@ -308,13 +308,13 @@ def reset_environment(robot, events, reset_time_s, fps):
)

def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
def reset_follower_position(robot_arm, target_position):
current_position = robot_arm.read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 50 is just an arbitrary number of interpolation steps
for pose in trajectory:
robot.send_action(pose)
robot_arm.write("Goal_Position", pose)
busy_wait(0.015)
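
A hedged usage sketch for the refactored helper above, assuming a connected robot whose follower arm bus is robot.follower_arms["main"] (the target joint values are made up):

import numpy as np

home_position = np.array([0.0, 90.0, 90.0, 0.0, 0.0, 0.0], dtype=np.float32)  # illustrative joint targets
reset_follower_position(robot.follower_arms["main"], target_position=home_position)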

@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760433331",
port="/dev/tty.usbmodem58760431091",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431631",
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -167,7 +167,7 @@ from lerobot.common.robot_devices.control_utils import (
warmup_record,
)
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import safe_disconnect
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.configs import parser

@@ -276,6 +276,7 @@ def record(

if not robot.is_connected:
robot.connect()

listener, events = init_keyboard_listener()

# Execute a few seconds without recording to:
@@ -284,14 +285,7 @@ def record(
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", cfg.play_sounds)
warmup_record(
robot,
events,
enable_teleoperation,
cfg.warmup_time_s,
cfg.display_data,
cfg.fps,
)
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)

if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
@@ -356,6 +350,7 @@ def replay(

dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
actions = dataset.hf_dataset.select_columns("action")

if not robot.is_connected:
robot.connect()

@@ -366,6 +361,9 @@ def replay(
action = actions[idx]["action"]
robot.send_action(action)

dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / cfg.fps - dt_s)

dt_s = time.perf_counter() - start_episode_t
log_control_info(robot, dt_s, fps=cfg.fps)
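
The replay loop above paces itself with busy_wait(1 / cfg.fps - dt_s); a standalone sketch of the same pacing pattern with illustrative values:

import time

from lerobot.common.robot_devices.utils import busy_wait

fps = 30
start_t = time.perf_counter()
# ... perform one control step here ...
dt_s = time.perf_counter() - start_t
busy_wait(1 / fps - dt_s)  # sleep away whatever remains of the 1/fps control period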

@@ -20,13 +20,11 @@ from functools import lru_cache
from queue import Empty
from statistics import mean, quantiles

# from lerobot.scripts.eval import eval_policy
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue

# TODO: Remove the import of maniskill
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.utils import busy_wait
@@ -39,20 +37,21 @@ from lerobot.common.utils.utils import (
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import (
Transition,
bytes_to_state_dict,
move_state_dict_to_device,
move_transition_to_device,
python_object_to_bytes,
transitions_to_bytes,
)
from lerobot.scripts.server.buffer import Transition
from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.scripts.server.network_utils import (
bytes_to_state_dict,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.scripts.server.utils import (
get_last_item_from_queue,
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers

ACTOR_SHUTDOWN_TIMEOUT = 30

@@ -134,21 +133,8 @@ def actor_cli(cfg: TrainPipelineConfig):
interactions_process.start()
receive_policy_process.start()

# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
reward_classifier = None
# if (
# cfg.env.reward_classifier["pretrained_path"] is not None
# and cfg.env.reward_classifier["config_path"] is not None
# ):
# reward_classifier = get_classifier(
# pretrained_path=cfg.env.reward_classifier["pretrained_path"],
# config_path=cfg.env.reward_classifier["config_path"],
# )

act_with_policy(
cfg=cfg,
reward_classifier=reward_classifier,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
transitions_queue=transitions_queue,
@@ -183,7 +169,6 @@ def actor_cli(cfg: TrainPipelineConfig):

def act_with_policy(
cfg: TrainPipelineConfig,
reward_classifier: nn.Module,
shutdown_event: any, # Event,
parameters_queue: Queue,
transitions_queue: Queue,
@@ -197,7 +182,6 @@ def act_with_policy(

Args:
cfg: Configuration settings for the interaction process.
reward_classifier: Reward classifier to use for the interaction process.
shutdown_event: Event to check if the process should shutdown.
parameters_queue: Queue to receive updated network parameters from the learner.
transitions_queue: Queue to send transitions to the learner.
@@ -262,16 +246,10 @@ def act_with_policy(

log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)

next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)

# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
)
next_obs, reward, done, truncated, info = online_env.step(action)

sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
@@ -286,11 +264,6 @@ def act_with_policy(
# Increment intervention steps counter
episode_intervention_steps += 1

# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")

list_transition_to_send_to_learner.append(
Transition(
state=obs,

@@ -15,26 +15,15 @@
# limitations under the License.

import functools
import io
import pickle # nosec B403: Safe usage of pickle
from contextlib import suppress
from typing import Any, Callable, Optional, Sequence, TypedDict
from typing import Callable, Optional, Sequence, TypedDict

import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None
from lerobot.scripts.server.utils import Transition

class BatchTransition(TypedDict):
@@ -47,103 +36,6 @@ class BatchTransition(TypedDict):
complementary_info: dict[str, torch.Tensor | float | int] | None = None

def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
device = torch.device(device)
non_blocking = device.type == "cuda"

# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}

# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)

# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)

if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)

if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)

# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}

# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition

def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
"""
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict

def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()

torch.save(state_dict, buffer)

return buffer.getvalue()

def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer) # nosec B614: Safe usage of torch.load

def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)

def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj

def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions

def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()

def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
@@ -514,7 +406,6 @@ class ReplayBuffer:
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
@@ -566,13 +457,6 @@ class ReplayBuffer:
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)

# Apply action mask/delta if needed
if action_mask is not None:
if first_action.dim() == 1:
first_action = first_action[action_mask]
else:
first_action = first_action[:, action_mask]

# Get complementary info if available
first_complementary_info = None
if (
@@ -597,8 +481,6 @@ class ReplayBuffer:
data[k] = v.to(storage_device)

action = data["action"]
if action_mask is not None:
action = action[action_mask] if action.dim() == 1 else action[:, action_mask]

replay_buffer.add(
state=data["state"],

@@ -524,7 +524,9 @@ def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)

desired_ee_pos = leader_ee
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -544,6 +546,8 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)

desired_ee_pos = np.diag(np.ones(4))
joint_positions = robot.follower_arms["main"].read("Present_Position")
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)

while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
@@ -561,25 +565,26 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
# Calculate delta between leader and follower end-effectors
# Scaling factor can be adjusted for sensitivity
scaling_factor = 1.0
ee_delta = (leader_ee - initial_leader_ee) * scaling_factor
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)

# Apply delta to current position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3]
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3]
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3]
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]

if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
# Compute joint targets via inverse kinematics
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
# Compute joint targets via inverse kinematics
target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)

initial_leader_ee = leader_ee.copy()
initial_leader_ee = leader_ee.copy()

# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))

# Logging
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
# Logging
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}")

busy_wait(1 / fps - (time.perf_counter() - loop_start_time))

@@ -715,8 +720,8 @@ if __name__ == "__main__":
"gamepad",
"keyboard_gym",
"gamepad_gym",
"leader_delta",
"leader",
"leader_abs",
],
help="Control mode to use",
)
@@ -768,11 +773,11 @@ if __name__ == "__main__":
env = make_robot_env(cfg, robot)
teleoperate_gym_env(env, controller, fps=cfg.fps)

elif args.mode == "leader":
elif args.mode == "leader_delta":
# Leader-follower modes don't use controllers
teleoperate_delta_inverse_kinematics_with_leader(robot)

elif args.mode == "leader_abs":
elif args.mode == "leader":
teleoperate_inverse_kinematics_with_leader(robot)

finally:

@@ -11,7 +11,7 @@ from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics

follower_port = "/dev/tty.usbmodem58760431631"
leader_port = "/dev/tty.usbmodem58760433331"
leader_port = "/dev/tty.usbmodem585A0077921"

def find_joint_bounds(

File diff suppressed because it is too large
@@ -23,8 +23,6 @@ from pathlib import Path
from pprint import pformat

import grpc

# Import generated stubs
import hilserl_pb2_grpc # type: ignore
import torch
from termcolor import colored
@@ -39,8 +37,6 @@ from lerobot.common.constants import (
TRAINING_STATE_DIR,
)
from lerobot.common.datasets.factory import make_dataset

# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
@@ -62,16 +58,17 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.scripts.server.network_utils import (
bytes_to_python_object,
bytes_to_transitions,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
state_to_bytes,
)
from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.scripts.server.utils import (
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)

LOG_PREFIX = "[LEARNER]"

@@ -307,17 +304,10 @@ def add_actor_information_and_train(
offline_replay_buffer = None

if cfg.dataset is not None:
active_action_dims = None
# TODO: FIX THIS
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
device=device,
storage_device=storage_device,
active_action_dims=active_action_dims,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer

@@ -342,7 +332,6 @@ def add_actor_information_and_train(
break

# Process all available transitions to the replay buffer, send by the actor server
logging.debug("[LEARNER] Waiting for transitions")
process_transitions(
transition_queue=transition_queue,
replay_buffer=replay_buffer,
@@ -351,35 +340,29 @@ def add_actor_information_and_train(
dataset_repo_id=dataset_repo_id,
shutdown_event=shutdown_event,
)
logging.debug("[LEARNER] Received transitions")

# Process all available interaction messages sent by the actor server
logging.debug("[LEARNER] Waiting for interactions")
interaction_message = process_interaction_messages(
interaction_message_queue=interaction_message_queue,
interaction_step_shift=interaction_step_shift,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
)
logging.debug("[LEARNER] Received interactions")

# Wait until the replay buffer has enough samples to start training
if len(replay_buffer) < online_step_before_learning:
continue

if online_iterator is None:
logging.debug("[LEARNER] Initializing online replay buffer iterator")
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)

if offline_replay_buffer is not None and offline_iterator is None:
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)

logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
# Sample from the iterators
@@ -967,7 +950,6 @@ def initialize_offline_replay_buffer(
cfg: TrainPipelineConfig,
device: str,
storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.
@@ -976,7 +958,6 @@ def initialize_offline_replay_buffer(
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking

Returns:
ReplayBuffer: Initialized offline replay buffer
@@ -997,7 +978,6 @@ def initialize_offline_replay_buffer(
offline_dataset,
device=device,
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity,
@@ -1096,44 +1076,6 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
parameters_queue.put(state_bytes)

def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
"""
Checks whether each parameter in the module has a gradient.

Args:
module (nn.Module): A PyTorch module whose parameters will be inspected.

Returns:
dict[str, bool]: A dictionary where each key is the parameter name and the value is
True if the parameter has an associated gradient (i.e. .grad is not None),
otherwise False.
"""
grad_status = {}
for name, param in module.named_parameters():
grad_status[name] = param.grad is not None
return grad_status

def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
"""
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.

Args:
actor (nn.Module): The actor model.
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
whether each parameter has a gradient.

Returns:
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
"""
# Get actor parameter names as a set.
model_param_names = {name for name, _ in model.named_parameters()}

# Intersect parameter names between actor and grad_status.
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
return overlapping

def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):

@@ -1,221 +0,0 @@
from typing import Any

import einops
import gymnasium as gym
import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

from lerobot.common.envs.configs import ManiskillEnvConfig

def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
# map to expected inputs for the policy
return_observations = {}
# TODO: You have to merge all tensors from agent key and extra key
# You don't keep sensor param key in the observation
# And you keep sensor data rgb
q_pos = observations["agent"]["qpos"]
q_vel = observations["agent"]["qvel"]
tcp_pos = observations["extra"]["tcp_pose"]
img = observations["sensor_data"]["base_camera"]["rgb"]

_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255

state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)

return_observations["observation.image"] = img
return_observations["observation.state"] = state
return return_observations

class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env)
if isinstance(device, str):
device = torch.device(device)
self.device = device

def observation(self, observation):
observation = preprocess_maniskill_observation(observation)
observation = {k: v.to(self.device) for k, v in observation.items()}
return observation

class ManiSkillCompat(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Any, dict[str, Any]]:
options = {}
return super().reset(seed=seed, options=options)

def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
reward = reward.item()
terminated = terminated.item()
truncated = truncated.item()
return obs, reward, terminated, truncated, info

class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))

def action(self, action):
action, telop = action
return action

class ManiSkillMultiplyActionWrapper(gym.Wrapper):
def __init__(self, env, multiply_factor: float = 1):
super().__init__(env)
self.multiply_factor = multiply_factor
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))

def step(self, action):
if isinstance(action, tuple):
action, telop = action
else:
telop = 0
action = action / self.multiply_factor
obs, reward, terminated, truncated, info = self.env.step((action, telop))
return obs, reward, terminated, truncated, info

class BatchCompatibleWrapper(gym.ObservationWrapper):
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""

def __init__(self, env):
super().__init__(env)

def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
if "state" in key and observation[key].dim() == 1:
observation[key] = observation[key].unsqueeze(0)
return observation

class TimeLimitWrapper(gym.Wrapper):
"""Adds a time limit to the environment based on fps and control_time."""

def __init__(self, env, control_time_s, fps):
super().__init__(env)
self.control_time_s = control_time_s
self.fps = fps
self.max_episode_steps = int(self.control_time_s * self.fps)
self.current_step = 0

def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
self.current_step += 1

if self.current_step >= self.max_episode_steps:
terminated = True

return obs, reward, terminated, truncated, info

def reset(self, *, seed=None, options=None):
self.current_step = 0
return super().reset(seed=seed, options=options)

class ManiskillMockGripperWrapper(gym.Wrapper):
def __init__(self, env, nb_discrete_actions: int = 3):
super().__init__(env)
new_shape = env.action_space[0].shape[0] + 1
new_low = np.concatenate([env.action_space[0].low, [0]])
new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))

def step(self, action):
if isinstance(action, tuple):
action_agent, telop_action = action
else:
telop_action = 0
action_agent = action
real_action = action_agent[:-1]
final_action = (real_action, telop_action)
obs, reward, terminated, truncated, info = self.env.step(final_action)
return obs, reward, terminated, truncated, info

def make_maniskill(
cfg: ManiskillEnvConfig,
n_envs: int | None = None,
) -> gym.Env:
"""
Factory function to create a ManiSkill environment with standard wrappers.

Args:
cfg: Configuration for the ManiSkill environment
n_envs: Number of parallel environments

Returns:
A wrapped ManiSkill environment
"""
env = gym.make(
cfg.task,
obs_mode=cfg.obs_type,
control_mode=cfg.control_mode,
render_mode=cfg.render_mode,
sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
num_envs=n_envs,
)

# Add video recording if enabled
if cfg.video_record.enabled:
env = RecordEpisode(
env,
output_dir=cfg.video_record.record_dir,
save_trajectory=True,
trajectory_name=cfg.video_record.trajectory_name,
save_video=True,
video_fps=30,
)

# Add observation and image processing
env = ManiSkillObservationWrapper(env, device=cfg.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
env.unwrapped.metadata["render_fps"] = cfg.fps

# Add compatibility wrappers
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
if cfg.mock_gripper:
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)

return env

@@ -17,10 +17,14 @@

import io
import logging
import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event, Queue
from typing import Any

import torch

from lerobot.scripts.server import hilserl_pb2
from lerobot.scripts.server.utils import Transition

CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB

@@ -60,7 +64,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")

def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""):
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
bytes_buffer = io.BytesIO()
step = 0

@@ -93,3 +97,44 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
step = 0

logging.debug(f"{log_prefix} Queue updated")

def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()

torch.save(state_dict, buffer)

return buffer.getvalue()

def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer) # nosec B614: Safe usage of torch.load

def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)

def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj

def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions

def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()
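
A small round-trip sketch for the byte helpers above, assuming they are imported from lerobot.scripts.server.network_utils (tensor contents are arbitrary):

import torch

from lerobot.scripts.server.network_utils import bytes_to_state_dict, state_to_bytes

state = {"weight": torch.randn(2, 3), "bias": torch.zeros(3)}
payload = state_to_bytes(state)          # torch.save into an in-memory buffer
restored = bytes_to_state_dict(payload)  # torch.load back from the raw bytes
assert torch.equal(state["weight"], restored["weight"])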

@@ -19,7 +19,9 @@ import logging
import signal
import sys
from queue import Empty
from typing import TypedDict

import torch
from torch.multiprocessing import Queue

shutdown_event_counter = 0
@@ -71,3 +73,69 @@ def get_last_item_from_queue(queue: Queue):
logging.debug(f"Drained {counter} items from queue")

return item

class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None

def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
device = torch.device(device)
non_blocking = device.type == "cuda"

# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}

# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)

# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)

if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)

if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)

# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}

# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition

def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the given device (CPU by default).
"""
|
||||
if isinstance(state_dict, torch.Tensor):
|
||||
return state_dict.to(device)
|
||||
elif isinstance(state_dict, dict):
|
||||
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
|
||||
elif isinstance(state_dict, list):
|
||||
return [move_state_dict_to_device(v, device=device) for v in state_dict]
|
||||
elif isinstance(state_dict, tuple):
|
||||
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
|
||||
else:
|
||||
return state_dict
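
A hedged sketch of constructing a Transition and moving it with the helper above (shapes and the observation key are illustrative):

import torch

transition = Transition(
    state={"observation.state": torch.zeros(1, 6)},
    action=torch.zeros(1, 6),
    reward=0.0,
    next_state={"observation.state": torch.zeros(1, 6)},
    done=False,
    truncated=False,
    complementary_info=None,
)
# Move every tensor in the transition to GPU when available, otherwise keep it on CPU.
transition = move_transition_to_device(transition, device="cuda" if torch.cuda.is_available() else "cpu")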