forked from tangger/lerobot
[Port HIl-Serl] Refactor gym-manipulator (#1034)
This commit is contained in:
@@ -182,15 +182,15 @@ class EEActionSpaceConfig:
|
|||||||
y_step_size: float
|
y_step_size: float
|
||||||
z_step_size: float
|
z_step_size: float
|
||||||
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
|
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
|
||||||
use_gamepad: bool = False
|
control_mode: str = "gamepad"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EnvWrapperConfig:
|
class EnvWrapperConfig:
|
||||||
"""Configuration for environment wrappers."""
|
"""Configuration for environment wrappers."""
|
||||||
|
|
||||||
|
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||||
display_cameras: bool = False
|
display_cameras: bool = False
|
||||||
use_relative_joint_positions: bool = True
|
|
||||||
add_joint_velocity_to_observation: bool = False
|
add_joint_velocity_to_observation: bool = False
|
||||||
add_current_to_observation: bool = False
|
add_current_to_observation: bool = False
|
||||||
add_ee_pose_to_observation: bool = False
|
add_ee_pose_to_observation: bool = False
|
||||||
@@ -199,13 +199,10 @@ class EnvWrapperConfig:
|
|||||||
control_time_s: float = 20.0
|
control_time_s: float = 20.0
|
||||||
fixed_reset_joint_positions: Optional[Any] = None
|
fixed_reset_joint_positions: Optional[Any] = None
|
||||||
reset_time_s: float = 5.0
|
reset_time_s: float = 5.0
|
||||||
joint_masking_action_space: Optional[Any] = None
|
|
||||||
ee_action_space_params: Optional[EEActionSpaceConfig] = None
|
|
||||||
use_gripper: bool = False
|
use_gripper: bool = False
|
||||||
gripper_quantization_threshold: float | None = 0.8
|
gripper_quantization_threshold: float | None = 0.8
|
||||||
gripper_penalty: float = 0.0
|
gripper_penalty: float = 0.0
|
||||||
gripper_penalty_in_reward: bool = False
|
gripper_penalty_in_reward: bool = False
|
||||||
open_gripper_on_reset: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||||
|
|||||||
@@ -308,13 +308,13 @@ def reset_environment(robot, events, reset_time_s, fps):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def reset_follower_position(robot: Robot, target_position):
|
def reset_follower_position(robot_arm, target_position):
|
||||||
current_position = robot.follower_arms["main"].read("Present_Position")
|
current_position = robot_arm.read("Present_Position")
|
||||||
trajectory = torch.from_numpy(
|
trajectory = torch.from_numpy(
|
||||||
np.linspace(current_position, target_position, 50)
|
np.linspace(current_position, target_position, 50)
|
||||||
) # NOTE: 30 is just an arbitrary number
|
) # NOTE: 30 is just an arbitrary number
|
||||||
for pose in trajectory:
|
for pose in trajectory:
|
||||||
robot.send_action(pose)
|
robot_arm.write("Goal_Position", pose)
|
||||||
busy_wait(0.015)
|
busy_wait(0.015)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
|||||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"main": FeetechMotorsBusConfig(
|
"main": FeetechMotorsBusConfig(
|
||||||
port="/dev/tty.usbmodem58760433331",
|
port="/dev/tty.usbmodem58760431091",
|
||||||
motors={
|
motors={
|
||||||
# name: (index, model)
|
# name: (index, model)
|
||||||
"shoulder_pan": [1, "sts3215"],
|
"shoulder_pan": [1, "sts3215"],
|
||||||
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
|||||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"main": FeetechMotorsBusConfig(
|
"main": FeetechMotorsBusConfig(
|
||||||
port="/dev/tty.usbmodem58760431631",
|
port="/dev/tty.usbmodem585A0076891",
|
||||||
motors={
|
motors={
|
||||||
# name: (index, model)
|
# name: (index, model)
|
||||||
"shoulder_pan": [1, "sts3215"],
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ from lerobot.common.robot_devices.control_utils import (
|
|||||||
warmup_record,
|
warmup_record,
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
|
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
|
||||||
from lerobot.common.robot_devices.utils import safe_disconnect
|
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
||||||
from lerobot.common.utils.utils import has_method, init_logging, log_say
|
from lerobot.common.utils.utils import has_method, init_logging, log_say
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
|
|
||||||
@@ -276,6 +276,7 @@ def record(
|
|||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
|
|
||||||
# Execute a few seconds without recording to:
|
# Execute a few seconds without recording to:
|
||||||
@@ -284,14 +285,7 @@ def record(
|
|||||||
# 3. place the cameras windows on screen
|
# 3. place the cameras windows on screen
|
||||||
enable_teleoperation = policy is None
|
enable_teleoperation = policy is None
|
||||||
log_say("Warmup record", cfg.play_sounds)
|
log_say("Warmup record", cfg.play_sounds)
|
||||||
warmup_record(
|
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
|
||||||
robot,
|
|
||||||
events,
|
|
||||||
enable_teleoperation,
|
|
||||||
cfg.warmup_time_s,
|
|
||||||
cfg.display_data,
|
|
||||||
cfg.fps,
|
|
||||||
)
|
|
||||||
|
|
||||||
if has_method(robot, "teleop_safety_stop"):
|
if has_method(robot, "teleop_safety_stop"):
|
||||||
robot.teleop_safety_stop()
|
robot.teleop_safety_stop()
|
||||||
@@ -356,6 +350,7 @@ def replay(
|
|||||||
|
|
||||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
|
||||||
actions = dataset.hf_dataset.select_columns("action")
|
actions = dataset.hf_dataset.select_columns("action")
|
||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
@@ -366,6 +361,9 @@ def replay(
|
|||||||
action = actions[idx]["action"]
|
action = actions[idx]["action"]
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_episode_t
|
||||||
|
busy_wait(1 / cfg.fps - dt_s)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_episode_t
|
dt_s = time.perf_counter() - start_episode_t
|
||||||
log_control_info(robot, dt_s, fps=cfg.fps)
|
log_control_info(robot, dt_s, fps=cfg.fps)
|
||||||
|
|
||||||
|
|||||||
@@ -20,13 +20,11 @@ from functools import lru_cache
|
|||||||
from queue import Empty
|
from queue import Empty
|
||||||
from statistics import mean, quantiles
|
from statistics import mean, quantiles
|
||||||
|
|
||||||
# from lerobot.scripts.eval import eval_policy
|
|
||||||
import grpc
|
import grpc
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.multiprocessing import Event, Queue
|
from torch.multiprocessing import Event, Queue
|
||||||
|
|
||||||
# TODO: Remove the import of maniskill
|
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.common.robot_devices.utils import busy_wait
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
@@ -39,20 +37,21 @@ from lerobot.common.utils.utils import (
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
|
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
|
||||||
from lerobot.scripts.server.buffer import (
|
from lerobot.scripts.server.buffer import Transition
|
||||||
Transition,
|
|
||||||
bytes_to_state_dict,
|
|
||||||
move_state_dict_to_device,
|
|
||||||
move_transition_to_device,
|
|
||||||
python_object_to_bytes,
|
|
||||||
transitions_to_bytes,
|
|
||||||
)
|
|
||||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||||
from lerobot.scripts.server.network_utils import (
|
from lerobot.scripts.server.network_utils import (
|
||||||
|
bytes_to_state_dict,
|
||||||
|
python_object_to_bytes,
|
||||||
receive_bytes_in_chunks,
|
receive_bytes_in_chunks,
|
||||||
send_bytes_in_chunks,
|
send_bytes_in_chunks,
|
||||||
|
transitions_to_bytes,
|
||||||
|
)
|
||||||
|
from lerobot.scripts.server.utils import (
|
||||||
|
get_last_item_from_queue,
|
||||||
|
move_state_dict_to_device,
|
||||||
|
move_transition_to_device,
|
||||||
|
setup_process_handlers,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
|
|
||||||
|
|
||||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||||
|
|
||||||
@@ -134,21 +133,8 @@ def actor_cli(cfg: TrainPipelineConfig):
|
|||||||
interactions_process.start()
|
interactions_process.start()
|
||||||
receive_policy_process.start()
|
receive_policy_process.start()
|
||||||
|
|
||||||
# HACK: FOR MANISKILL we do not have a reward classifier
|
|
||||||
# TODO: Remove this once we merge into main
|
|
||||||
reward_classifier = None
|
|
||||||
# if (
|
|
||||||
# cfg.env.reward_classifier["pretrained_path"] is not None
|
|
||||||
# and cfg.env.reward_classifier["config_path"] is not None
|
|
||||||
# ):
|
|
||||||
# reward_classifier = get_classifier(
|
|
||||||
# pretrained_path=cfg.env.reward_classifier["pretrained_path"],
|
|
||||||
# config_path=cfg.env.reward_classifier["config_path"],
|
|
||||||
# )
|
|
||||||
|
|
||||||
act_with_policy(
|
act_with_policy(
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
reward_classifier=reward_classifier,
|
|
||||||
shutdown_event=shutdown_event,
|
shutdown_event=shutdown_event,
|
||||||
parameters_queue=parameters_queue,
|
parameters_queue=parameters_queue,
|
||||||
transitions_queue=transitions_queue,
|
transitions_queue=transitions_queue,
|
||||||
@@ -183,7 +169,6 @@ def actor_cli(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
def act_with_policy(
|
def act_with_policy(
|
||||||
cfg: TrainPipelineConfig,
|
cfg: TrainPipelineConfig,
|
||||||
reward_classifier: nn.Module,
|
|
||||||
shutdown_event: any, # Event,
|
shutdown_event: any, # Event,
|
||||||
parameters_queue: Queue,
|
parameters_queue: Queue,
|
||||||
transitions_queue: Queue,
|
transitions_queue: Queue,
|
||||||
@@ -197,7 +182,6 @@ def act_with_policy(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Configuration settings for the interaction process.
|
cfg: Configuration settings for the interaction process.
|
||||||
reward_classifier: Reward classifier to use for the interaction process.
|
|
||||||
shutdown_event: Event to check if the process should shutdown.
|
shutdown_event: Event to check if the process should shutdown.
|
||||||
parameters_queue: Queue to receive updated network parameters from the learner.
|
parameters_queue: Queue to receive updated network parameters from the learner.
|
||||||
transitions_queue: Queue to send transitions to the learner.
|
transitions_queue: Queue to send transitions to the learner.
|
||||||
@@ -262,16 +246,10 @@ def act_with_policy(
|
|||||||
|
|
||||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||||
|
|
||||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
|
||||||
else:
|
else:
|
||||||
# TODO (azouitine): Make a custom space for torch tensor
|
|
||||||
action = online_env.action_space.sample()
|
action = online_env.action_space.sample()
|
||||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
|
||||||
|
|
||||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||||
action = (
|
|
||||||
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
sum_reward_episode += float(reward)
|
sum_reward_episode += float(reward)
|
||||||
# Increment total steps counter for intervention rate
|
# Increment total steps counter for intervention rate
|
||||||
@@ -286,11 +264,6 @@ def act_with_policy(
|
|||||||
# Increment intervention steps counter
|
# Increment intervention steps counter
|
||||||
episode_intervention_steps += 1
|
episode_intervention_steps += 1
|
||||||
|
|
||||||
# Check for NaN values in observations
|
|
||||||
for key, tensor in obs.items():
|
|
||||||
if torch.isnan(tensor).any():
|
|
||||||
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
|
|
||||||
|
|
||||||
list_transition_to_send_to_learner.append(
|
list_transition_to_send_to_learner.append(
|
||||||
Transition(
|
Transition(
|
||||||
state=obs,
|
state=obs,
|
||||||
|
|||||||
@@ -15,26 +15,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import io
|
|
||||||
import pickle # nosec B403: Safe usage of pickle
|
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import Any, Callable, Optional, Sequence, TypedDict
|
from typing import Callable, Optional, Sequence, TypedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.scripts.server.utils import Transition
|
||||||
|
|
||||||
class Transition(TypedDict):
|
|
||||||
state: dict[str, torch.Tensor]
|
|
||||||
action: torch.Tensor
|
|
||||||
reward: float
|
|
||||||
next_state: dict[str, torch.Tensor]
|
|
||||||
done: bool
|
|
||||||
truncated: bool
|
|
||||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class BatchTransition(TypedDict):
|
class BatchTransition(TypedDict):
|
||||||
@@ -47,103 +36,6 @@ class BatchTransition(TypedDict):
|
|||||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||||
|
|
||||||
|
|
||||||
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
|
|
||||||
device = torch.device(device)
|
|
||||||
non_blocking = device.type == "cuda"
|
|
||||||
|
|
||||||
# Move state tensors to device
|
|
||||||
transition["state"] = {
|
|
||||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Move action to device
|
|
||||||
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
# Move reward and done if they are tensors
|
|
||||||
if isinstance(transition["reward"], torch.Tensor):
|
|
||||||
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
if isinstance(transition["done"], torch.Tensor):
|
|
||||||
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
if isinstance(transition["truncated"], torch.Tensor):
|
|
||||||
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
# Move next_state tensors to device
|
|
||||||
transition["next_state"] = {
|
|
||||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Move complementary_info tensors if present
|
|
||||||
if transition.get("complementary_info") is not None:
|
|
||||||
for key, val in transition["complementary_info"].items():
|
|
||||||
if isinstance(val, torch.Tensor):
|
|
||||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
|
||||||
elif isinstance(val, (int, float, bool)):
|
|
||||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
|
||||||
return transition
|
|
||||||
|
|
||||||
|
|
||||||
def move_state_dict_to_device(state_dict, device="cpu"):
|
|
||||||
"""
|
|
||||||
Recursively move all tensors in a (potentially) nested
|
|
||||||
dict/list/tuple structure to the CPU.
|
|
||||||
"""
|
|
||||||
if isinstance(state_dict, torch.Tensor):
|
|
||||||
return state_dict.to(device)
|
|
||||||
elif isinstance(state_dict, dict):
|
|
||||||
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
|
|
||||||
elif isinstance(state_dict, list):
|
|
||||||
return [move_state_dict_to_device(v, device=device) for v in state_dict]
|
|
||||||
elif isinstance(state_dict, tuple):
|
|
||||||
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
|
|
||||||
else:
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
|
||||||
"""Convert model state dict to flat array for transmission"""
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
|
|
||||||
torch.save(state_dict, buffer)
|
|
||||||
|
|
||||||
return buffer.getvalue()
|
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
|
||||||
buffer = io.BytesIO(buffer)
|
|
||||||
buffer.seek(0)
|
|
||||||
return torch.load(buffer) # nosec B614: Safe usage of torch.load
|
|
||||||
|
|
||||||
|
|
||||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
|
||||||
return pickle.dumps(python_object)
|
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_python_object(buffer: bytes) -> Any:
|
|
||||||
buffer = io.BytesIO(buffer)
|
|
||||||
buffer.seek(0)
|
|
||||||
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
|
|
||||||
# Add validation checks here
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
|
||||||
buffer = io.BytesIO(buffer)
|
|
||||||
buffer.seek(0)
|
|
||||||
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
|
|
||||||
# Add validation checks here
|
|
||||||
return transitions
|
|
||||||
|
|
||||||
|
|
||||||
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
torch.save(transitions, buffer)
|
|
||||||
return buffer.getvalue()
|
|
||||||
|
|
||||||
|
|
||||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Perform a per-image random crop over a batch of images in a vectorized way.
|
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||||
@@ -514,7 +406,6 @@ class ReplayBuffer:
|
|||||||
device: str = "cuda:0",
|
device: str = "cuda:0",
|
||||||
state_keys: Optional[Sequence[str]] = None,
|
state_keys: Optional[Sequence[str]] = None,
|
||||||
capacity: Optional[int] = None,
|
capacity: Optional[int] = None,
|
||||||
action_mask: Optional[Sequence[int]] = None,
|
|
||||||
image_augmentation_function: Optional[Callable] = None,
|
image_augmentation_function: Optional[Callable] = None,
|
||||||
use_drq: bool = True,
|
use_drq: bool = True,
|
||||||
storage_device: str = "cpu",
|
storage_device: str = "cpu",
|
||||||
@@ -566,13 +457,6 @@ class ReplayBuffer:
|
|||||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||||
first_action = first_transition["action"].to(device)
|
first_action = first_transition["action"].to(device)
|
||||||
|
|
||||||
# Apply action mask/delta if needed
|
|
||||||
if action_mask is not None:
|
|
||||||
if first_action.dim() == 1:
|
|
||||||
first_action = first_action[action_mask]
|
|
||||||
else:
|
|
||||||
first_action = first_action[:, action_mask]
|
|
||||||
|
|
||||||
# Get complementary info if available
|
# Get complementary info if available
|
||||||
first_complementary_info = None
|
first_complementary_info = None
|
||||||
if (
|
if (
|
||||||
@@ -597,8 +481,6 @@ class ReplayBuffer:
|
|||||||
data[k] = v.to(storage_device)
|
data[k] = v.to(storage_device)
|
||||||
|
|
||||||
action = data["action"]
|
action = data["action"]
|
||||||
if action_mask is not None:
|
|
||||||
action = action[action_mask] if action.dim() == 1 else action[:, action_mask]
|
|
||||||
|
|
||||||
replay_buffer.add(
|
replay_buffer.add(
|
||||||
state=data["state"],
|
state=data["state"],
|
||||||
|
|||||||
@@ -524,7 +524,9 @@ def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
|
|||||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||||
|
|
||||||
desired_ee_pos = leader_ee
|
desired_ee_pos = leader_ee
|
||||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
target_joint_state = kinematics.ik(
|
||||||
|
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||||
|
)
|
||||||
robot.send_action(torch.from_numpy(target_joint_state))
|
robot.send_action(torch.from_numpy(target_joint_state))
|
||||||
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
|
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
|
||||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||||
@@ -544,6 +546,8 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
|||||||
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||||
|
|
||||||
desired_ee_pos = np.diag(np.ones(4))
|
desired_ee_pos = np.diag(np.ones(4))
|
||||||
|
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||||
|
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||||
|
|
||||||
while time.perf_counter() - timestep < 60.0:
|
while time.perf_counter() - timestep < 60.0:
|
||||||
loop_start_time = time.perf_counter()
|
loop_start_time = time.perf_counter()
|
||||||
@@ -561,25 +565,26 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
|||||||
# Calculate delta between leader and follower end-effectors
|
# Calculate delta between leader and follower end-effectors
|
||||||
# Scaling factor can be adjusted for sensitivity
|
# Scaling factor can be adjusted for sensitivity
|
||||||
scaling_factor = 1.0
|
scaling_factor = 1.0
|
||||||
ee_delta = (leader_ee - initial_leader_ee) * scaling_factor
|
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
|
||||||
|
|
||||||
# Apply delta to current position
|
# Apply delta to current position
|
||||||
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3]
|
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
|
||||||
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3]
|
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
|
||||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3]
|
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
|
||||||
|
|
||||||
if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
|
# Compute joint targets via inverse kinematics
|
||||||
# Compute joint targets via inverse kinematics
|
target_joint_state = kinematics.ik(
|
||||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||||
|
)
|
||||||
|
|
||||||
initial_leader_ee = leader_ee.copy()
|
initial_leader_ee = leader_ee.copy()
|
||||||
|
|
||||||
# Send command to robot
|
# Send command to robot
|
||||||
robot.send_action(torch.from_numpy(target_joint_state))
|
robot.send_action(torch.from_numpy(target_joint_state))
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
|
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
|
||||||
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
|
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
|
||||||
|
|
||||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||||
|
|
||||||
@@ -715,8 +720,8 @@ if __name__ == "__main__":
|
|||||||
"gamepad",
|
"gamepad",
|
||||||
"keyboard_gym",
|
"keyboard_gym",
|
||||||
"gamepad_gym",
|
"gamepad_gym",
|
||||||
|
"leader_delta",
|
||||||
"leader",
|
"leader",
|
||||||
"leader_abs",
|
|
||||||
],
|
],
|
||||||
help="Control mode to use",
|
help="Control mode to use",
|
||||||
)
|
)
|
||||||
@@ -768,11 +773,11 @@ if __name__ == "__main__":
|
|||||||
env = make_robot_env(cfg, robot)
|
env = make_robot_env(cfg, robot)
|
||||||
teleoperate_gym_env(env, controller, fps=cfg.fps)
|
teleoperate_gym_env(env, controller, fps=cfg.fps)
|
||||||
|
|
||||||
elif args.mode == "leader":
|
elif args.mode == "leader_delta":
|
||||||
# Leader-follower modes don't use controllers
|
# Leader-follower modes don't use controllers
|
||||||
teleoperate_delta_inverse_kinematics_with_leader(robot)
|
teleoperate_delta_inverse_kinematics_with_leader(robot)
|
||||||
|
|
||||||
elif args.mode == "leader_abs":
|
elif args.mode == "leader":
|
||||||
teleoperate_inverse_kinematics_with_leader(robot)
|
teleoperate_inverse_kinematics_with_leader(robot)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from lerobot.configs import parser
|
|||||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||||
|
|
||||||
follower_port = "/dev/tty.usbmodem58760431631"
|
follower_port = "/dev/tty.usbmodem58760431631"
|
||||||
leader_port = "/dev/tty.usbmodem58760433331"
|
leader_port = "/dev/tty.usbmodem585A0077921"
|
||||||
|
|
||||||
|
|
||||||
def find_joint_bounds(
|
def find_joint_bounds(
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -23,8 +23,6 @@ from pathlib import Path
|
|||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
# Import generated stubs
|
|
||||||
import hilserl_pb2_grpc # type: ignore
|
import hilserl_pb2_grpc # type: ignore
|
||||||
import torch
|
import torch
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
@@ -39,8 +37,6 @@ from lerobot.common.constants import (
|
|||||||
TRAINING_STATE_DIR,
|
TRAINING_STATE_DIR,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
|
||||||
# TODO: Remove the import of maniskill
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
@@ -62,16 +58,17 @@ from lerobot.common.utils.wandb_utils import WandBLogger
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.scripts.server import learner_service
|
from lerobot.scripts.server import learner_service
|
||||||
from lerobot.scripts.server.buffer import (
|
from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||||
ReplayBuffer,
|
from lerobot.scripts.server.network_utils import (
|
||||||
bytes_to_python_object,
|
bytes_to_python_object,
|
||||||
bytes_to_transitions,
|
bytes_to_transitions,
|
||||||
concatenate_batch_transitions,
|
|
||||||
move_state_dict_to_device,
|
|
||||||
move_transition_to_device,
|
|
||||||
state_to_bytes,
|
state_to_bytes,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.server.utils import setup_process_handlers
|
from lerobot.scripts.server.utils import (
|
||||||
|
move_state_dict_to_device,
|
||||||
|
move_transition_to_device,
|
||||||
|
setup_process_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
LOG_PREFIX = "[LEARNER]"
|
LOG_PREFIX = "[LEARNER]"
|
||||||
|
|
||||||
@@ -307,17 +304,10 @@ def add_actor_information_and_train(
|
|||||||
offline_replay_buffer = None
|
offline_replay_buffer = None
|
||||||
|
|
||||||
if cfg.dataset is not None:
|
if cfg.dataset is not None:
|
||||||
active_action_dims = None
|
|
||||||
# TODO: FIX THIS
|
|
||||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
|
||||||
active_action_dims = [
|
|
||||||
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
|
|
||||||
]
|
|
||||||
offline_replay_buffer = initialize_offline_replay_buffer(
|
offline_replay_buffer = initialize_offline_replay_buffer(
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
device=device,
|
device=device,
|
||||||
storage_device=storage_device,
|
storage_device=storage_device,
|
||||||
active_action_dims=active_action_dims,
|
|
||||||
)
|
)
|
||||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||||
|
|
||||||
@@ -342,7 +332,6 @@ def add_actor_information_and_train(
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Process all available transitions to the replay buffer, send by the actor server
|
# Process all available transitions to the replay buffer, send by the actor server
|
||||||
logging.debug("[LEARNER] Waiting for transitions")
|
|
||||||
process_transitions(
|
process_transitions(
|
||||||
transition_queue=transition_queue,
|
transition_queue=transition_queue,
|
||||||
replay_buffer=replay_buffer,
|
replay_buffer=replay_buffer,
|
||||||
@@ -351,35 +340,29 @@ def add_actor_information_and_train(
|
|||||||
dataset_repo_id=dataset_repo_id,
|
dataset_repo_id=dataset_repo_id,
|
||||||
shutdown_event=shutdown_event,
|
shutdown_event=shutdown_event,
|
||||||
)
|
)
|
||||||
logging.debug("[LEARNER] Received transitions")
|
|
||||||
|
|
||||||
# Process all available interaction messages sent by the actor server
|
# Process all available interaction messages sent by the actor server
|
||||||
logging.debug("[LEARNER] Waiting for interactions")
|
|
||||||
interaction_message = process_interaction_messages(
|
interaction_message = process_interaction_messages(
|
||||||
interaction_message_queue=interaction_message_queue,
|
interaction_message_queue=interaction_message_queue,
|
||||||
interaction_step_shift=interaction_step_shift,
|
interaction_step_shift=interaction_step_shift,
|
||||||
wandb_logger=wandb_logger,
|
wandb_logger=wandb_logger,
|
||||||
shutdown_event=shutdown_event,
|
shutdown_event=shutdown_event,
|
||||||
)
|
)
|
||||||
logging.debug("[LEARNER] Received interactions")
|
|
||||||
|
|
||||||
# Wait until the replay buffer has enough samples to start training
|
# Wait until the replay buffer has enough samples to start training
|
||||||
if len(replay_buffer) < online_step_before_learning:
|
if len(replay_buffer) < online_step_before_learning:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if online_iterator is None:
|
if online_iterator is None:
|
||||||
logging.debug("[LEARNER] Initializing online replay buffer iterator")
|
|
||||||
online_iterator = replay_buffer.get_iterator(
|
online_iterator = replay_buffer.get_iterator(
|
||||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||||
)
|
)
|
||||||
|
|
||||||
if offline_replay_buffer is not None and offline_iterator is None:
|
if offline_replay_buffer is not None and offline_iterator is None:
|
||||||
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
|
|
||||||
offline_iterator = offline_replay_buffer.get_iterator(
|
offline_iterator = offline_replay_buffer.get_iterator(
|
||||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.debug("[LEARNER] Starting optimization loop")
|
|
||||||
time_for_one_optimization_step = time.time()
|
time_for_one_optimization_step = time.time()
|
||||||
for _ in range(utd_ratio - 1):
|
for _ in range(utd_ratio - 1):
|
||||||
# Sample from the iterators
|
# Sample from the iterators
|
||||||
@@ -967,7 +950,6 @@ def initialize_offline_replay_buffer(
|
|||||||
cfg: TrainPipelineConfig,
|
cfg: TrainPipelineConfig,
|
||||||
device: str,
|
device: str,
|
||||||
storage_device: str,
|
storage_device: str,
|
||||||
active_action_dims: list[int] | None = None,
|
|
||||||
) -> ReplayBuffer:
|
) -> ReplayBuffer:
|
||||||
"""
|
"""
|
||||||
Initialize an offline replay buffer from a dataset.
|
Initialize an offline replay buffer from a dataset.
|
||||||
@@ -976,7 +958,6 @@ def initialize_offline_replay_buffer(
|
|||||||
cfg (TrainPipelineConfig): Training configuration
|
cfg (TrainPipelineConfig): Training configuration
|
||||||
device (str): Device to store tensors on
|
device (str): Device to store tensors on
|
||||||
storage_device (str): Device for storage optimization
|
storage_device (str): Device for storage optimization
|
||||||
active_action_dims (list[int] | None): Active action dimensions for masking
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ReplayBuffer: Initialized offline replay buffer
|
ReplayBuffer: Initialized offline replay buffer
|
||||||
@@ -997,7 +978,6 @@ def initialize_offline_replay_buffer(
|
|||||||
offline_dataset,
|
offline_dataset,
|
||||||
device=device,
|
device=device,
|
||||||
state_keys=cfg.policy.input_features.keys(),
|
state_keys=cfg.policy.input_features.keys(),
|
||||||
action_mask=active_action_dims,
|
|
||||||
storage_device=storage_device,
|
storage_device=storage_device,
|
||||||
optimize_memory=True,
|
optimize_memory=True,
|
||||||
capacity=cfg.policy.offline_buffer_capacity,
|
capacity=cfg.policy.offline_buffer_capacity,
|
||||||
@@ -1096,44 +1076,6 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
|||||||
parameters_queue.put(state_bytes)
|
parameters_queue.put(state_bytes)
|
||||||
|
|
||||||
|
|
||||||
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
|
|
||||||
"""
|
|
||||||
Checks whether each parameter in the module has a gradient.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (nn.Module): A PyTorch module whose parameters will be inspected.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, bool]: A dictionary where each key is the parameter name and the value is
|
|
||||||
True if the parameter has an associated gradient (i.e. .grad is not None),
|
|
||||||
otherwise False.
|
|
||||||
"""
|
|
||||||
grad_status = {}
|
|
||||||
for name, param in module.named_parameters():
|
|
||||||
grad_status[name] = param.grad is not None
|
|
||||||
return grad_status
|
|
||||||
|
|
||||||
|
|
||||||
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
|
|
||||||
"""
|
|
||||||
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
actor (nn.Module): The actor model.
|
|
||||||
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
|
|
||||||
whether each parameter has a gradient.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
|
|
||||||
"""
|
|
||||||
# Get actor parameter names as a set.
|
|
||||||
model_param_names = {name for name, _ in model.named_parameters()}
|
|
||||||
|
|
||||||
# Intersect parameter names between actor and grad_status.
|
|
||||||
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
|
|
||||||
return overlapping
|
|
||||||
|
|
||||||
|
|
||||||
def process_interaction_message(
|
def process_interaction_message(
|
||||||
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1,221 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import gymnasium as gym
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from mani_skill.utils.wrappers.record import RecordEpisode
|
|
||||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
|
||||||
|
|
||||||
from lerobot.common.envs.configs import ManiskillEnvConfig
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_maniskill_observation(
|
|
||||||
observations: dict[str, np.ndarray],
|
|
||||||
) -> dict[str, torch.Tensor]:
|
|
||||||
"""Convert environment observation to LeRobot format observation.
|
|
||||||
Args:
|
|
||||||
observation: Dictionary of observation batches from a Gym vector environment.
|
|
||||||
Returns:
|
|
||||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
|
||||||
"""
|
|
||||||
# map to expected inputs for the policy
|
|
||||||
return_observations = {}
|
|
||||||
# TODO: You have to merge all tensors from agent key and extra key
|
|
||||||
# You don't keep sensor param key in the observation
|
|
||||||
# And you keep sensor data rgb
|
|
||||||
q_pos = observations["agent"]["qpos"]
|
|
||||||
q_vel = observations["agent"]["qvel"]
|
|
||||||
tcp_pos = observations["extra"]["tcp_pose"]
|
|
||||||
img = observations["sensor_data"]["base_camera"]["rgb"]
|
|
||||||
|
|
||||||
_, h, w, c = img.shape
|
|
||||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
|
||||||
|
|
||||||
# sanity check that images are uint8
|
|
||||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
|
||||||
|
|
||||||
# convert to channel first of type float32 in range [0,1]
|
|
||||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
|
||||||
img = img.type(torch.float32)
|
|
||||||
img /= 255
|
|
||||||
|
|
||||||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
|
||||||
|
|
||||||
return_observations["observation.image"] = img
|
|
||||||
return_observations["observation.state"] = state
|
|
||||||
return return_observations
|
|
||||||
|
|
||||||
|
|
||||||
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
|
||||||
def __init__(self, env, device: torch.device = "cuda"):
|
|
||||||
super().__init__(env)
|
|
||||||
if isinstance(device, str):
|
|
||||||
device = torch.device(device)
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def observation(self, observation):
|
|
||||||
observation = preprocess_maniskill_observation(observation)
|
|
||||||
observation = {k: v.to(self.device) for k, v in observation.items()}
|
|
||||||
return observation
|
|
||||||
|
|
||||||
|
|
||||||
class ManiSkillCompat(gym.Wrapper):
|
|
||||||
def __init__(self, env):
|
|
||||||
super().__init__(env)
|
|
||||||
new_action_space_shape = env.action_space.shape[-1]
|
|
||||||
new_low = np.squeeze(env.action_space.low, axis=0)
|
|
||||||
new_high = np.squeeze(env.action_space.high, axis=0)
|
|
||||||
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
||||||
) -> tuple[Any, dict[str, Any]]:
|
|
||||||
options = {}
|
|
||||||
return super().reset(seed=seed, options=options)
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
|
||||||
reward = reward.item()
|
|
||||||
terminated = terminated.item()
|
|
||||||
truncated = truncated.item()
|
|
||||||
return obs, reward, terminated, truncated, info
|
|
||||||
|
|
||||||
|
|
||||||
class ManiSkillActionWrapper(gym.ActionWrapper):
|
|
||||||
def __init__(self, env):
|
|
||||||
super().__init__(env)
|
|
||||||
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
|
|
||||||
|
|
||||||
def action(self, action):
|
|
||||||
action, telop = action
|
|
||||||
return action
|
|
||||||
|
|
||||||
|
|
||||||
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
|
||||||
def __init__(self, env, multiply_factor: float = 1):
|
|
||||||
super().__init__(env)
|
|
||||||
self.multiply_factor = multiply_factor
|
|
||||||
action_space_agent: gym.spaces.Box = env.action_space[0]
|
|
||||||
action_space_agent.low = action_space_agent.low * multiply_factor
|
|
||||||
action_space_agent.high = action_space_agent.high * multiply_factor
|
|
||||||
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
if isinstance(action, tuple):
|
|
||||||
action, telop = action
|
|
||||||
else:
|
|
||||||
telop = 0
|
|
||||||
action = action / self.multiply_factor
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step((action, telop))
|
|
||||||
return obs, reward, terminated, truncated, info
|
|
||||||
|
|
||||||
|
|
||||||
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
|
||||||
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
|
|
||||||
|
|
||||||
def __init__(self, env):
|
|
||||||
super().__init__(env)
|
|
||||||
|
|
||||||
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
|
||||||
for key in observation:
|
|
||||||
if "image" in key and observation[key].dim() == 3:
|
|
||||||
observation[key] = observation[key].unsqueeze(0)
|
|
||||||
if "state" in key and observation[key].dim() == 1:
|
|
||||||
observation[key] = observation[key].unsqueeze(0)
|
|
||||||
return observation
|
|
||||||
|
|
||||||
|
|
||||||
class TimeLimitWrapper(gym.Wrapper):
|
|
||||||
"""Adds a time limit to the environment based on fps and control_time."""
|
|
||||||
|
|
||||||
def __init__(self, env, control_time_s, fps):
|
|
||||||
super().__init__(env)
|
|
||||||
self.control_time_s = control_time_s
|
|
||||||
self.fps = fps
|
|
||||||
self.max_episode_steps = int(self.control_time_s * self.fps)
|
|
||||||
self.current_step = 0
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
|
||||||
self.current_step += 1
|
|
||||||
|
|
||||||
if self.current_step >= self.max_episode_steps:
|
|
||||||
terminated = True
|
|
||||||
|
|
||||||
return obs, reward, terminated, truncated, info
|
|
||||||
|
|
||||||
def reset(self, *, seed=None, options=None):
|
|
||||||
self.current_step = 0
|
|
||||||
return super().reset(seed=seed, options=options)
|
|
||||||
|
|
||||||
|
|
||||||
class ManiskillMockGripperWrapper(gym.Wrapper):
|
|
||||||
def __init__(self, env, nb_discrete_actions: int = 3):
|
|
||||||
super().__init__(env)
|
|
||||||
new_shape = env.action_space[0].shape[0] + 1
|
|
||||||
new_low = np.concatenate([env.action_space[0].low, [0]])
|
|
||||||
new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
|
|
||||||
action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
|
|
||||||
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
if isinstance(action, tuple):
|
|
||||||
action_agent, telop_action = action
|
|
||||||
else:
|
|
||||||
telop_action = 0
|
|
||||||
action_agent = action
|
|
||||||
real_action = action_agent[:-1]
|
|
||||||
final_action = (real_action, telop_action)
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(final_action)
|
|
||||||
return obs, reward, terminated, truncated, info
|
|
||||||
|
|
||||||
|
|
||||||
def make_maniskill(
|
|
||||||
cfg: ManiskillEnvConfig,
|
|
||||||
n_envs: int | None = None,
|
|
||||||
) -> gym.Env:
|
|
||||||
"""
|
|
||||||
Factory function to create a ManiSkill environment with standard wrappers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg: Configuration for the ManiSkill environment
|
|
||||||
n_envs: Number of parallel environments
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A wrapped ManiSkill environment
|
|
||||||
"""
|
|
||||||
env = gym.make(
|
|
||||||
cfg.task,
|
|
||||||
obs_mode=cfg.obs_type,
|
|
||||||
control_mode=cfg.control_mode,
|
|
||||||
render_mode=cfg.render_mode,
|
|
||||||
sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
|
|
||||||
num_envs=n_envs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add video recording if enabled
|
|
||||||
if cfg.video_record.enabled:
|
|
||||||
env = RecordEpisode(
|
|
||||||
env,
|
|
||||||
output_dir=cfg.video_record.record_dir,
|
|
||||||
save_trajectory=True,
|
|
||||||
trajectory_name=cfg.video_record.trajectory_name,
|
|
||||||
save_video=True,
|
|
||||||
video_fps=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add observation and image processing
|
|
||||||
env = ManiSkillObservationWrapper(env, device=cfg.device)
|
|
||||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
|
||||||
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
|
|
||||||
env.unwrapped.metadata["render_fps"] = cfg.fps
|
|
||||||
|
|
||||||
# Add compatibility wrappers
|
|
||||||
env = ManiSkillCompat(env)
|
|
||||||
env = ManiSkillActionWrapper(env)
|
|
||||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
|
|
||||||
if cfg.mock_gripper:
|
|
||||||
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
|
|
||||||
|
|
||||||
return env
|
|
||||||
@@ -17,10 +17,14 @@
|
|||||||
|
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import pickle # nosec B403: Safe usage for internal serialization only
|
||||||
from multiprocessing import Event, Queue
|
from multiprocessing import Event, Queue
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from lerobot.scripts.server import hilserl_pb2
|
from lerobot.scripts.server import hilserl_pb2
|
||||||
|
from lerobot.scripts.server.utils import Transition
|
||||||
|
|
||||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||||
|
|
||||||
@@ -60,7 +64,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
|
|||||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||||
|
|
||||||
|
|
||||||
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""):
|
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
|
||||||
bytes_buffer = io.BytesIO()
|
bytes_buffer = io.BytesIO()
|
||||||
step = 0
|
step = 0
|
||||||
|
|
||||||
@@ -93,3 +97,44 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
|
|||||||
step = 0
|
step = 0
|
||||||
|
|
||||||
logging.debug(f"{log_prefix} Queue updated")
|
logging.debug(f"{log_prefix} Queue updated")
|
||||||
|
|
||||||
|
|
||||||
|
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||||
|
"""Convert model state dict to flat array for transmission"""
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
|
||||||
|
torch.save(state_dict, buffer)
|
||||||
|
|
||||||
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||||
|
buffer = io.BytesIO(buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
return torch.load(buffer) # nosec B614: Safe usage of torch.load
|
||||||
|
|
||||||
|
|
||||||
|
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||||
|
return pickle.dumps(python_object)
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_python_object(buffer: bytes) -> Any:
|
||||||
|
buffer = io.BytesIO(buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
|
||||||
|
# Add validation checks here
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
||||||
|
buffer = io.BytesIO(buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
|
||||||
|
# Add validation checks here
|
||||||
|
return transitions
|
||||||
|
|
||||||
|
|
||||||
|
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.save(transitions, buffer)
|
||||||
|
return buffer.getvalue()
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ import logging
|
|||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from queue import Empty
|
from queue import Empty
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.multiprocessing import Queue
|
from torch.multiprocessing import Queue
|
||||||
|
|
||||||
shutdown_event_counter = 0
|
shutdown_event_counter = 0
|
||||||
@@ -71,3 +73,69 @@ def get_last_item_from_queue(queue: Queue):
|
|||||||
logging.debug(f"Drained {counter} items from queue")
|
logging.debug(f"Drained {counter} items from queue")
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
class Transition(TypedDict):
|
||||||
|
state: dict[str, torch.Tensor]
|
||||||
|
action: torch.Tensor
|
||||||
|
reward: float
|
||||||
|
next_state: dict[str, torch.Tensor]
|
||||||
|
done: bool
|
||||||
|
truncated: bool
|
||||||
|
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
|
||||||
|
device = torch.device(device)
|
||||||
|
non_blocking = device.type == "cuda"
|
||||||
|
|
||||||
|
# Move state tensors to device
|
||||||
|
transition["state"] = {
|
||||||
|
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Move action to device
|
||||||
|
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
# Move reward and done if they are tensors
|
||||||
|
if isinstance(transition["reward"], torch.Tensor):
|
||||||
|
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
if isinstance(transition["done"], torch.Tensor):
|
||||||
|
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
if isinstance(transition["truncated"], torch.Tensor):
|
||||||
|
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
# Move next_state tensors to device
|
||||||
|
transition["next_state"] = {
|
||||||
|
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Move complementary_info tensors if present
|
||||||
|
if transition.get("complementary_info") is not None:
|
||||||
|
for key, val in transition["complementary_info"].items():
|
||||||
|
if isinstance(val, torch.Tensor):
|
||||||
|
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||||
|
elif isinstance(val, (int, float, bool)):
|
||||||
|
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||||
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
def move_state_dict_to_device(state_dict, device="cpu"):
|
||||||
|
"""
|
||||||
|
Recursively move all tensors in a (potentially) nested
|
||||||
|
dict/list/tuple structure to the CPU.
|
||||||
|
"""
|
||||||
|
if isinstance(state_dict, torch.Tensor):
|
||||||
|
return state_dict.to(device)
|
||||||
|
elif isinstance(state_dict, dict):
|
||||||
|
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
|
||||||
|
elif isinstance(state_dict, list):
|
||||||
|
return [move_state_dict_to_device(v, device=device) for v in state_dict]
|
||||||
|
elif isinstance(state_dict, tuple):
|
||||||
|
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
|
||||||
|
else:
|
||||||
|
return state_dict
|
||||||
|
|||||||
Reference in New Issue
Block a user