import logging
import sys
import time
from collections import deque
from threading import Lock
from typing import Annotated, Any, Dict, Sequence, Tuple

import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F  # noqa: N812

from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import (
    busy_wait,
    is_headless,
    reset_follower_position,
)
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.utils.utils import log_say
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics

logging.basicConfig(level=logging.INFO)

MAX_GRIPPER_COMMAND = 30


class TorchBox(gym.spaces.Box):
    """
    A version of gym.spaces.Box that handles PyTorch tensors.

    This class extends gym.spaces.Box to work with PyTorch tensors,
    providing compatibility between NumPy arrays and PyTorch tensors.
    """

    def __init__(
        self,
        low: float | Sequence[float] | np.ndarray,
        high: float | Sequence[float] | np.ndarray,
        shape: Sequence[int] | None = None,
        np_dtype: np.dtype | type = np.float32,
        torch_dtype: torch.dtype = torch.float32,
        device: str = "cpu",
        seed: int | np.random.Generator | None = None,
    ) -> None:
        """
        Initialize the PyTorch-compatible Box space.

        Args:
            low: Lower bounds of the space.
            high: Upper bounds of the space.
            shape: Shape of the space. If None, inferred from low and high.
            np_dtype: NumPy data type for internal storage.
            torch_dtype: PyTorch data type for tensor conversion.
            device: PyTorch device for returned tensors.
            seed: Random seed for sampling.
        """
        super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed)
        self.torch_dtype = torch_dtype
        self.device = device

    def sample(self) -> torch.Tensor:
        """
        Sample a random point from the space.

        Returns:
            A PyTorch tensor within the space bounds.
        """
        arr = super().sample()
        return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device)

    def contains(self, x: torch.Tensor) -> bool:
        """
        Check if a tensor is within the space bounds.

        Args:
            x: The PyTorch tensor to check.

        Returns:
            Boolean indicating whether the tensor is within bounds.
        """
        # Move to CPU/NumPy and cast to the internal dtype
        arr = x.detach().cpu().numpy().astype(self.dtype, copy=False)
        return super().contains(arr)

    def seed(self, seed: int | np.random.Generator | None = None):
        """
        Set the random seed for sampling.

        Args:
            seed: The random seed to use.

        Returns:
            List containing the seed.
        """
        super().seed(seed)
        return [seed]

    def __repr__(self) -> str:
        """
        Return a string representation of the space.

        Returns:
            Formatted string with space details.
        """
        return (
            f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, "
            f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})"
        )


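# Example usage of TorchBox (a minimal sketch; the bounds and shape below are
# illustrative, not values used elsewhere in this module):
#
#     space = TorchBox(low=-1.0, high=1.0, shape=(6,), device="cpu")
#     action = space.sample()          # torch.float32 tensor of shape (6,)
#     assert space.contains(action)    # the check round-trips through NumPy

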
class TorchActionWrapper(gym.Wrapper):
    """
    Wrapper that changes the action space to use PyTorch tensors.

    This wrapper modifies the action space to return PyTorch tensors when sampled
    and handles converting PyTorch actions to NumPy when stepping the environment.
    """

    def __init__(self, env: gym.Env, device: str):
        """
        Initialize the PyTorch action space wrapper.

        Args:
            env: The environment to wrap.
            device: The PyTorch device to use for tensor operations.
        """
        super().__init__(env)
        self.action_space = TorchBox(
            low=env.action_space.low,
            high=env.action_space.high,
            shape=env.action_space.shape,
            torch_dtype=torch.float32,
            device=torch.device("cpu"),
        )

    def step(self, action: torch.Tensor):
        """
        Step the environment with a PyTorch tensor action.

        This method handles conversion from PyTorch tensors to NumPy arrays
        for compatibility with the underlying environment.

        Args:
            action: PyTorch tensor action to take.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info).
        """
        if action.dim() == 2:
            action = action.squeeze(0)
        action = action.detach().cpu().numpy()
        return self.env.step(action)


class RobotEnv(gym.Env):
    """
    Gym-compatible environment for evaluating robotic control policies with integrated human intervention.

    This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports
    both relative (delta) and absolute joint position commands and automatically configures its observation
    and action spaces based on the robot's sensors and configuration.
    """

    def __init__(
        self,
        robot,
        display_cameras: bool = False,
    ):
        """
        Initialize the RobotEnv environment.

        The environment is set up with a robot interface, which is used to capture observations and send
        joint commands. The setup supports both relative (delta) adjustments and absolute joint positions
        for controlling the robot.

        Args:
            robot: The robot interface object used to connect and interact with the physical robot.
            display_cameras: If True, the robot's camera feeds will be displayed during execution.
        """
        super().__init__()

        self.robot = robot
        self.display_cameras = display_cameras

        # Connect to the robot if not already connected.
        if not self.robot.is_connected:
            self.robot.connect()

        # Episode tracking.
        self.current_step = 0
        self.episode_data = None

        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

        self._setup_spaces()

    def _setup_spaces(self):
        """
        Dynamically configure the observation and action spaces based on the robot's capabilities.

        Observation Space:
            - For keys containing "image": a Box space with pixel values ranging from 0 to 255.
            - For non-image keys: a Box space under 'observation.state' with a suitable range.

        Action Space:
            - The action space is defined as a Box space representing joint position commands. It is
              defined as relative (delta) or absolute, based on the configuration.
        """
        example_obs = self.robot.capture_observation()

        # Define observation spaces for images and other states.
        image_keys = [key for key in example_obs if "image" in key]
        observation_spaces = {
            key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
            for key in image_keys
        }
        observation_spaces["observation.state"] = gym.spaces.Box(
            low=0,
            high=10,
            shape=example_obs["observation.state"].shape,
            dtype=np.float32,
        )

        self.observation_space = gym.spaces.Dict(observation_spaces)

        # Define the action space for joint positions along with setting an intervention flag.
        action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
        bounds = {}
        bounds["min"] = np.ones(action_dim) * -1000
        bounds["max"] = np.ones(action_dim) * 1000

        self.action_space = gym.spaces.Box(
            low=bounds["min"],
            high=bounds["max"],
            shape=(action_dim,),
            dtype=np.float32,
        )

    def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
        """
        Reset the environment to its initial state.

        This method resets the step counter and clears any episodic data.

        Args:
            seed: A seed for random number generation to ensure reproducibility.
            options: Additional options to influence the reset behavior.

        Returns:
            A tuple containing:
                - observation (dict): The initial sensor observation.
                - info (dict): A dictionary with supplementary information, including the key "is_intervention".
        """
        super().reset(seed=seed, options=options)

        # Capture the initial observation.
        observation = self.robot.capture_observation()

        # Reset episode tracking variables.
        self.current_step = 0
        self.episode_data = None

        return observation, {"is_intervention": False}

    def step(self, action) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
        """
        Execute a single step within the environment using the specified action.

        The provided action is processed and sent to the robot as joint position commands
        that may be either absolute values or deltas based on the environment configuration.

        Args:
            action: The commanded joint positions as a NumPy array (outer wrappers convert
                torch tensors to NumPy before they reach this environment).

        Returns:
            A tuple containing:
                - observation (dict): The new sensor observation after taking the step.
                - reward (float): The step reward (default is 0.0 within this wrapper).
                - terminated (bool): True if the episode has reached a terminal state.
                - truncated (bool): True if the episode was truncated (e.g., time constraints).
                - info (dict): Additional debugging information including intervention status.
        """
        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

        self.robot.send_action(torch.from_numpy(action))
        observation = self.robot.capture_observation()

        if self.display_cameras:
            self.render()

        self.current_step += 1

        reward = 0.0
        terminated = False
        truncated = False

        return (
            observation,
            reward,
            terminated,
            truncated,
            {"is_intervention": False},
        )

    def render(self):
        """
        Render the current state of the environment by displaying the robot's camera feeds.
        """
        import cv2

        observation = self.robot.capture_observation()
        image_keys = [key for key in observation if "image" in key]

        for key in image_keys:
            cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

    def close(self):
        """
        Close the environment and clean up resources by disconnecting the robot.

        If the robot is currently connected, this method properly terminates the connection to
        ensure that all associated resources are released.
        """
        if self.robot.is_connected:
            self.robot.disconnect()


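# Example wiring (a minimal sketch; `make_robot_from_config` and the config object
# come from the caller, mirroring `make_robot_env` at the bottom of this module):
#
#     robot = make_robot_from_config(cfg.robot)
#     env = RobotEnv(robot=robot, display_cameras=False)
#     obs, info = env.reset()
#     obs, reward, terminated, truncated, info = env.step(
#         env.action_space.sample()  # raw joint-position command (NumPy)
#     )

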
class AddJointVelocityToObservation(gym.ObservationWrapper):
    """
    Wrapper that adds joint velocity information to the observation.

    This wrapper computes joint velocities by tracking changes in joint positions over time,
    and extends the observation space to include these velocities.
    """

    def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6):
        """
        Initialize the joint velocity wrapper.

        Args:
            env: The environment to wrap.
            joint_velocity_limits: Maximum expected joint velocity for space bounds.
            fps: Frames per second used to calculate velocity (position delta / time).
            num_dof: Number of degrees of freedom (joints) in the robot.
        """
        super().__init__(env)

        # Extend observation space to include joint velocities
        old_low = self.observation_space["observation.state"].low
        old_high = self.observation_space["observation.state"].high
        old_shape = self.observation_space["observation.state"].shape

        self.last_joint_positions = np.zeros(num_dof)

        new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits])
        new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits])

        new_shape = (old_shape[0] + num_dof,)

        self.observation_space["observation.state"] = gym.spaces.Box(
            low=new_low,
            high=new_high,
            shape=new_shape,
            dtype=np.float32,
        )

        self.dt = 1.0 / fps

    def observation(self, observation):
        """
        Add joint velocity information to the observation.

        Args:
            observation: The original observation from the environment.

        Returns:
            The modified observation with joint velocities.
        """
        joint_velocities = (observation["observation.state"] - self.last_joint_positions) / self.dt
        self.last_joint_positions = observation["observation.state"].clone()
        observation["observation.state"] = torch.cat(
            [observation["observation.state"], joint_velocities], dim=-1
        )
        return observation


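# Worked example of the finite-difference velocity above (illustrative numbers):
# at fps=30, dt = 1/30 s, so a joint that moves from 10.0 to 10.5 between frames
# is reported as (10.5 - 10.0) / (1/30) = 15.0 units/s in the appended state slice.

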
class AddCurrentToObservation(gym.ObservationWrapper):
    """
    Wrapper that adds motor current information to the observation.

    This wrapper extends the observation space to include the current values
    from each motor, providing information about the forces being applied.
    """

    def __init__(self, env, max_current=500, num_dof=6):
        """
        Initialize the current observation wrapper.

        Args:
            env: The environment to wrap.
            max_current: Maximum expected current for space bounds.
            num_dof: Number of degrees of freedom (joints) in the robot.
        """
        super().__init__(env)

        # Extend observation space to include motor currents
        old_low = self.observation_space["observation.state"].low
        old_high = self.observation_space["observation.state"].high
        old_shape = self.observation_space["observation.state"].shape

        new_low = np.concatenate([old_low, np.zeros(num_dof)])
        new_high = np.concatenate([old_high, np.ones(num_dof) * max_current])

        new_shape = (old_shape[0] + num_dof,)

        self.observation_space["observation.state"] = gym.spaces.Box(
            low=new_low,
            high=new_high,
            shape=new_shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        """
        Add current information to the observation.

        Args:
            observation: The original observation from the environment.

        Returns:
            The modified observation with current values.
        """
        present_current = (
            self.unwrapped.robot.follower_arms["main"].read("Present_Current").astype(np.float32)
        )
        observation["observation.state"] = torch.cat(
            [observation["observation.state"], torch.from_numpy(present_current)], dim=-1
        )
        return observation


class RewardWrapper(gym.Wrapper):
    """
    Wrapper that predicts the per-step reward with a trained classifier.
    """

    def __init__(self, env, reward_classifier, device="cuda"):
        """
        Wrapper to add reward prediction to the environment using a trained classifier.

        Args:
            env: The environment to wrap.
            reward_classifier: The reward classifier model.
            device: The device to run the model on.
        """
        # Properly initialize gym.Wrapper so space delegation works.
        super().__init__(env)

        self.device = device

        self.reward_classifier = torch.compile(reward_classifier)
        self.reward_classifier.to(self.device)

    def step(self, action):
        """
        Execute a step and compute the reward using the classifier.

        Args:
            action: The action to take in the environment.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info).
        """
        observation, _, terminated, truncated, info = self.env.step(action)

        images = {}
        for key in observation:
            if "image" in key:
                images[key] = observation[key].to(self.device, non_blocking=(self.device == "cuda"))
                if images[key].dim() == 3:
                    images[key] = images[key].unsqueeze(0)

        start_time = time.perf_counter()
        with torch.inference_mode():
            success = (
                self.reward_classifier.predict_reward(images, threshold=0.7)
                if self.reward_classifier is not None
                else 0.0
            )
        info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time)

        reward = 0.0
        if success == 1.0:
            terminated = True
            reward = 1.0

        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """
        Reset the environment.

        Args:
            seed: Random seed for reproducibility.
            options: Additional reset options.

        Returns:
            The initial observation and info from the wrapped environment.
        """
        return self.env.reset(seed=seed, options=options)


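# Sketch of how this wrapper is attached (mirrors `make_robot_env` below; the
# classifier's `predict_reward(images, threshold)` interface is the one this
# file already relies on):
#
#     reward_classifier = init_reward_classifier(cfg)
#     if reward_classifier is not None:
#         env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)

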
class TimeLimitWrapper(gym.Wrapper):
    """
    Wrapper that adds a time limit to episodes and tracks execution time.

    This wrapper terminates episodes after a specified time has elapsed, providing
    better control over episode length.
    """

    def __init__(self, env, control_time_s, fps):
        """
        Initialize the time limit wrapper.

        Args:
            env: The environment to wrap.
            control_time_s: Maximum episode duration in seconds.
            fps: Frames per second for calculating the maximum number of steps.
        """
        # Properly initialize gym.Wrapper so space delegation works.
        super().__init__(env)
        self.control_time_s = control_time_s
        self.fps = fps

        self.last_timestamp = 0.0
        self.episode_time_in_s = 0.0

        self.max_episode_steps = int(self.control_time_s * self.fps)

        self.current_step = 0

    def step(self, action):
        """
        Step the environment and track time elapsed.

        Args:
            action: The action to take in the environment.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info).
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        time_since_last_step = time.perf_counter() - self.last_timestamp
        self.episode_time_in_s += time_since_last_step
        self.last_timestamp = time.perf_counter()
        self.current_step += 1

        # Warn if the last step took longer than the period implied by the target fps.
        if 1.0 / time_since_last_step < self.fps:
            logging.debug(f"Current timestep exceeded expected fps {self.fps}")

        if self.current_step >= self.max_episode_steps:
            terminated = True
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """
        Reset the environment and time tracking.

        Args:
            seed: Random seed for reproducibility.
            options: Additional reset options.

        Returns:
            The initial observation and info from the wrapped environment.
        """
        self.episode_time_in_s = 0.0
        self.last_timestamp = time.perf_counter()
        self.current_step = 0
        return self.env.reset(seed=seed, options=options)


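# Step-budget arithmetic (illustrative): with control_time_s=20 and fps=10, the
# wrapper allows int(20 * 10) = 200 steps before forcing `terminated = True`.

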
class ImageCropResizeWrapper(gym.Wrapper):
    """
    Wrapper that crops and resizes image observations.

    This wrapper processes image observations to focus on relevant regions by
    cropping and then resizing to a standard size.
    """

    def __init__(
        self,
        env,
        crop_params_dict: Dict[str, Annotated[Tuple[int], 4]],
        resize_size=None,
    ):
        """
        Initialize the image crop and resize wrapper.

        Args:
            env: The environment to wrap.
            crop_params_dict: Dictionary mapping image observation keys to crop parameters
                (top, left, height, width).
            resize_size: Target size for resized images (height, width). Defaults to (128, 128).
        """
        super().__init__(env)
        self.env = env
        self.crop_params_dict = crop_params_dict
        print(f"obs_keys: {self.env.observation_space}")
        print(f"crop params dict: {crop_params_dict.keys()}")
        for key_crop in crop_params_dict:
            if key_crop not in self.env.observation_space.keys():  # noqa: SIM118
                raise ValueError(f"Key {key_crop} not in observation space")

        # Resolve the default size *before* it is used to rebuild the observation space.
        self.resize_size = resize_size
        if self.resize_size is None:
            self.resize_size = (128, 128)

        for key in crop_params_dict:
            new_shape = (3, self.resize_size[0], self.resize_size[1])
            self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)

    def step(self, action):
        """
        Step the environment and process image observations.

        Args:
            action: The action to take in the environment.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info) with processed images.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        for k in self.crop_params_dict:
            device = obs[k].device
            if obs[k].dim() >= 3:
                # Reshape to combine height and width dimensions for easier calculation
                batch_size = obs[k].size(0)
                channels = obs[k].size(1)
                flattened_spatial_dims = obs[k].view(batch_size, channels, -1)

                # Calculate standard deviation across spatial dimensions (H, W).
                # If any channel has std=0, all pixels in that channel have the same value,
                # which helps catch a mistakenly covered camera or a black image.
                std_per_channel = torch.std(flattened_spatial_dims, dim=2)
                if (std_per_channel <= 0.02).any():
                    logging.warning(
                        f"Potential hardware issue detected: All pixels have the same value in observation {k}"
                    )

            if device == torch.device("mps:0"):
                obs[k] = obs[k].cpu()

            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
            obs[k] = F.resize(obs[k], self.resize_size)
            # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1]
            obs[k] = obs[k].clamp(0.0, 1.0)
            obs[k] = obs[k].to(device)

        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """
        Reset the environment and process image observations.

        Args:
            seed: Random seed for reproducibility.
            options: Additional reset options.

        Returns:
            Tuple of (observation, info) with processed images.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        for k in self.crop_params_dict:
            device = obs[k].device
            if device == torch.device("mps:0"):
                obs[k] = obs[k].cpu()
            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
            obs[k] = F.resize(obs[k], self.resize_size)
            obs[k] = obs[k].clamp(0.0, 1.0)
            obs[k] = obs[k].to(device)
        return obs, info


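# Example crop parameters (a sketch with illustrative values; keys must match
# image keys in the observation space, and the tuple order is (top, left,
# height, width) as consumed by torchvision's F.crop):
#
#     crop_params_dict = {"observation.images.side": (0, 80, 360, 360)}
#     env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=(128, 128))

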
class ConvertToLeRobotObservation(gym.ObservationWrapper):
    """
    Wrapper that converts standard observations to LeRobot format.

    This wrapper processes observations to match the expected format for LeRobot,
    including normalizing image values and moving tensors to the specified device.
    """

    def __init__(self, env, device: str = "cpu"):
        """
        Initialize the LeRobot observation converter.

        Args:
            env: The environment to wrap.
            device: Target device for the observation tensors.
        """
        super().__init__(env)

        self.device = torch.device(device)

    def observation(self, observation):
        """
        Convert observations to LeRobot format.

        Args:
            observation: The original observation from the environment.

        Returns:
            The processed observation with normalized images and proper tensor formats.
        """
        for key in observation:
            observation[key] = observation[key].float()
            if "image" in key:
                observation[key] = observation[key].permute(2, 0, 1)
                observation[key] /= 255.0
        observation = {
            key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
            for key in observation
        }

        return observation


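# Shape/value bookkeeping for the conversion above: a raw (H, W, 3) uint8 frame in
# [0, 255] becomes a (3, H, W) float32 tensor in [0.0, 1.0] after permute + divide.

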
class ResetWrapper(gym.Wrapper):
    """
    Wrapper that handles environment reset procedures.

    This wrapper provides additional functionality during environment reset,
    including the option to reset to a fixed pose or allow manual reset.
    """

    def __init__(
        self,
        env: RobotEnv,
        reset_pose: np.ndarray | None = None,
        reset_time_s: float = 5,
    ):
        """
        Initialize the reset wrapper.

        Args:
            env: The environment to wrap.
            reset_pose: Fixed joint positions to reset to. If None, manual reset is used.
            reset_time_s: Time in seconds to wait after reset or allowed for manual reset.
        """
        super().__init__(env)
        self.reset_time_s = reset_time_s
        self.reset_pose = reset_pose
        self.robot = self.unwrapped.robot

    def reset(self, *, seed=None, options=None):
        """
        Reset the environment with either a fixed or a manual reset procedure.

        If reset_pose is provided, the robot will move to that position.
        Otherwise, manual teleoperation control is allowed for reset_time_s seconds.

        Args:
            seed: Random seed for reproducibility.
            options: Additional reset options.

        Returns:
            The initial observation and info from the wrapped environment.
        """
        start_time = time.perf_counter()
        if self.reset_pose is not None:
            log_say("Reset the environment.", play_sounds=True)
            reset_follower_position(self.robot.follower_arms["main"], self.reset_pose)
            log_say("Reset the environment done.", play_sounds=True)

            if len(self.robot.leader_arms) > 0:
                self.robot.leader_arms["main"].write("Torque_Enable", 1)
                log_say("Reset the leader robot.", play_sounds=True)
                reset_follower_position(self.robot.leader_arms["main"], self.reset_pose)
                log_say("Reset the leader robot done.", play_sounds=True)
        else:
            log_say(
                f"Manually reset the environment for {self.reset_time_s} seconds.",
                play_sounds=True,
            )
            start_time = time.perf_counter()
            while time.perf_counter() - start_time < self.reset_time_s:
                self.robot.teleop_step()

            log_say("Manual reset of the environment done.", play_sounds=True)

        busy_wait(self.reset_time_s - (time.perf_counter() - start_time))

        return super().reset(seed=seed, options=options)


class BatchCompatibleWrapper(gym.ObservationWrapper):
    """
    Wrapper that ensures observations are compatible with batch processing.

    This wrapper adds a batch dimension to observations that don't already have one,
    making them compatible with models that expect batched inputs.
    """

    def __init__(self, env):
        """
        Initialize the batch compatibility wrapper.

        Args:
            env: The environment to wrap.
        """
        super().__init__(env)

    def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Add batch dimensions to observations if needed.

        Args:
            observation: Dictionary of observation tensors.

        Returns:
            Dictionary of observation tensors with batch dimensions.
        """
        for key in observation:
            if "image" in key and observation[key].dim() == 3:
                observation[key] = observation[key].unsqueeze(0)
            if "state" in key and observation[key].dim() == 1:
                observation[key] = observation[key].unsqueeze(0)
            if "velocity" in key and observation[key].dim() == 1:
                observation[key] = observation[key].unsqueeze(0)
        return observation


class GripperPenaltyWrapper(gym.RewardWrapper):
    """
    Wrapper that adds penalties for inefficient gripper commands.

    This wrapper modifies rewards to discourage excessive gripper movement
    or commands that attempt to move the gripper beyond its physical limits.
    """

    def __init__(self, env, penalty: float = -0.1):
        """
        Initialize the gripper penalty wrapper.

        Args:
            env: The environment to wrap.
            penalty: Negative reward value to apply for inefficient gripper actions.
        """
        super().__init__(env)
        self.penalty = penalty
        self.last_gripper_state = None

    def reward(self, reward, action):
        """
        Apply penalties to reward based on gripper actions.

        Args:
            reward: The original reward from the environment.
            action: The action that was taken.

        Returns:
            Modified reward with penalty applied if necessary.
        """
        gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND

        # Recenter the raw [0, 2] gripper command to [-1, 1].
        action_normalized = action - 1.0

        # Penalize opening a mostly-open gripper and closing a mostly-closed one.
        gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
            gripper_state_normalized > 0.75 and action_normalized < -0.5
        )

        return reward + self.penalty * int(gripper_penalty_bool)

    def step(self, action):
        """
        Step the environment and compute the gripper penalty.

        Args:
            action: The action to take in the environment.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info). The penalty is
            reported via info["discrete_penalty"] rather than folded into the returned reward.
        """
        self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
        gripper_action = action[-1]
        obs, reward, terminated, truncated, info = self.env.step(action)
        gripper_penalty = self.reward(reward, gripper_action)

        info["discrete_penalty"] = gripper_penalty

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the environment and penalty tracking.

        Args:
            **kwargs: Keyword arguments passed to the wrapped environment's reset.

        Returns:
            The initial observation and info with gripper penalty initialized.
        """
        self.last_gripper_state = None
        obs, info = super().reset(**kwargs)
        info["gripper_penalty"] = 0.0
        return obs, info


class GripperActionWrapper(gym.ActionWrapper):
    """
    Wrapper that processes gripper control commands.

    This wrapper quantizes and processes gripper commands, adding a sleep time between
    consecutive gripper actions to prevent rapid toggling.
    """

    def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0):
        """
        Initialize the gripper action wrapper.

        Args:
            env: The environment to wrap.
            quantization_threshold: Threshold below which gripper commands are quantized to zero.
            gripper_sleep: Minimum time in seconds between consecutive gripper commands.
        """
        super().__init__(env)
        self.quantization_threshold = quantization_threshold
        self.gripper_sleep = gripper_sleep
        self.last_gripper_action_time = 0.0
        self.last_gripper_action = None

    def action(self, action):
        """
        Process gripper commands in the action.

        Args:
            action: The original action from the agent.

        Returns:
            Modified action with processed gripper command.
        """
        if self.gripper_sleep > 0.0:
            if (
                self.last_gripper_action is not None
                and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep
            ):
                action[-1] = self.last_gripper_action
            else:
                self.last_gripper_action_time = time.perf_counter()
                self.last_gripper_action = action[-1]

        gripper_command = action[-1]
        # Gripper actions lie in [0, 2]; recenter to [-1, 1] so they can be
        # quantized to -1, 0 or 1.
        gripper_command = gripper_command - 1.0

        if self.quantization_threshold is not None:
            # Quantize gripper command to -1, 0 or 1
            gripper_command = (
                np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
            )
        gripper_command = gripper_command * MAX_GRIPPER_COMMAND
        gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
        gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
        action[-1] = gripper_action.item()
        return action

    def reset(self, **kwargs):
        """
        Reset the gripper action tracking.

        Args:
            **kwargs: Keyword arguments passed to the wrapped environment's reset.

        Returns:
            The initial observation and info.
        """
        obs, info = super().reset(**kwargs)
        self.last_gripper_action_time = 0.0
        self.last_gripper_action = None
        return obs, info


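# Worked quantization example (illustrative): with quantization_threshold=0.2 and
# MAX_GRIPPER_COMMAND=30, a raw command of 1.9 recenters to 0.9 -> sign(+1) -> +30,
# so the goal position becomes clip(current_position + 30, 0, 30); a raw command
# of 1.1 recenters to 0.1, falls under the threshold, and leaves the gripper in place.

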
class EEActionWrapper(gym.ActionWrapper):
    """
    Wrapper that converts end-effector space actions to joint space actions.

    This wrapper takes actions defined in cartesian space (x, y, z, gripper) and
    converts them to joint space actions using inverse kinematics.
    """

    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
        """
        Initialize the end-effector action wrapper.

        Args:
            env: The environment to wrap.
            ee_action_space_params: Parameters defining the end-effector action space.
            use_gripper: Whether to include gripper control in the action space.
        """
        super().__init__(env)
        self.ee_action_space_params = ee_action_space_params
        self.use_gripper = use_gripper

        # Initialize kinematics instance for the appropriate robot type
        robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
        self.kinematics = RobotKinematics(robot_type)
        self.fk_function = self.kinematics.fk_gripper_tip

        action_space_bounds = np.array(
            [
                ee_action_space_params.x_step_size,
                ee_action_space_params.y_step_size,
                ee_action_space_params.z_step_size,
            ]
        )
        if self.use_gripper:
            # Gripper commands: open at 2.0, closed at 0.0
            min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
            max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
        else:
            min_action_space_bounds = -action_space_bounds
            max_action_space_bounds = action_space_bounds

        self.action_space = gym.spaces.Box(
            low=min_action_space_bounds,
            high=max_action_space_bounds,
            shape=(3 + int(self.use_gripper),),
            dtype=np.float32,
        )

        self.bounds = ee_action_space_params.bounds

    def action(self, action):
        """
        Convert end-effector action to joint space action.

        Args:
            action: End-effector action in cartesian space.

        Returns:
            Converted action in joint space.
        """
        desired_ee_pos = np.eye(4)

        if self.use_gripper:
            gripper_command = action[-1]
            action = action[:-1]

        current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
        current_ee_pos = self.fk_function(current_joint_pos)
        desired_ee_pos[:3, 3] = np.clip(
            current_ee_pos[:3, 3] + action,
            self.bounds["min"],
            self.bounds["max"],
        )
        target_joint_pos = self.kinematics.ik(
            current_joint_pos,
            desired_ee_pos,
            position_only=True,
            fk_func=self.fk_function,
        )
        if self.use_gripper:
            target_joint_pos[-1] = gripper_command

        return target_joint_pos


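# The delta-EE control flow above, one line of pseudocode per stage (a sketch):
# target_xyz = clip(fk(q_now)[:3, 3] + [dx, dy, dz], bounds["min"], bounds["max"]),
# then q_goal = ik(q_now, target_xyz); a gripper command, when present, bypasses
# IK and is written directly to the last joint.

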
class EEObservationWrapper(gym.ObservationWrapper):
    """
    Wrapper that adds end-effector pose information to observations.

    This wrapper computes the end-effector pose using forward kinematics
    and adds it to the observation space.
    """

    def __init__(self, env, ee_pose_limits):
        """
        Initialize the end-effector observation wrapper.

        Args:
            env: The environment to wrap.
            ee_pose_limits: Dictionary with 'min' and 'max' keys containing limits for EE pose.
        """
        super().__init__(env)

        # Extend observation space to include end effector pose
        prev_space = self.observation_space["observation.state"]

        self.observation_space["observation.state"] = gym.spaces.Box(
            low=np.concatenate([prev_space.low, ee_pose_limits["min"]]),
            high=np.concatenate([prev_space.high, ee_pose_limits["max"]]),
            shape=(prev_space.shape[0] + 3,),
            dtype=np.float32,
        )

        # Initialize kinematics instance for the appropriate robot type
        robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
        self.kinematics = RobotKinematics(robot_type)
        self.fk_function = self.kinematics.fk_gripper_tip

    def observation(self, observation):
        """
        Add end-effector pose to the observation.

        Args:
            observation: Original observation from the environment.

        Returns:
            Enhanced observation with end-effector pose information.
        """
        current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
        current_ee_pos = self.fk_function(current_joint_pos)
        observation["observation.state"] = torch.cat(
            [
                observation["observation.state"],
                torch.from_numpy(current_ee_pos[:3, 3]),
            ],
            dim=-1,
        )
        return observation


###########################################################
# Wrappers related to human intervention and input devices
###########################################################


class BaseLeaderControlWrapper(gym.Wrapper):
    """
    Base class for leader-follower robot control wrappers.

    This wrapper enables human intervention through a leader-follower robot setup,
    where the human can control a leader robot to guide the follower robot's movements.
    """

    def __init__(
        self, env, use_geared_leader_arm: bool = False, ee_action_space_params=None, use_gripper=False
    ):
        """
        Initialize the base leader control wrapper.

        Args:
            env: The environment to wrap.
            use_geared_leader_arm: Whether to use a geared leader arm setup.
            ee_action_space_params: Parameters defining the end-effector action space.
            use_gripper: Whether to include gripper control.
        """
        super().__init__(env)
        self.robot_leader = env.unwrapped.robot.leader_arms["main"]
        self.robot_follower = env.unwrapped.robot.follower_arms["main"]
        self.use_geared_leader_arm = use_geared_leader_arm
        self.ee_action_space_params = ee_action_space_params
        self.use_ee_action_space = ee_action_space_params is not None
        self.use_gripper: bool = use_gripper

        # Set up keyboard event tracking
        self._init_keyboard_events()
        self.event_lock = Lock()  # Thread-safe access to events

        # Initialize robot control
        robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
        self.kinematics = RobotKinematics(robot_type)
        self.prev_leader_ee = None
        self.prev_leader_pos = None
        self.leader_torque_enabled = True

        # Configure the leader arm.
        # NOTE: Lower the gains of the leader arm for automatic take-over.
        # With lower gains we can manually move the leader arm without risk of injury
        # to ourselves or the robot. With higher gains, it would be dangerous and
        # difficult to modify the leader's pose while torque is enabled.
        # The default value for P_Coefficient is 32.
        self.robot_leader.write("Torque_Enable", 1)
        self.robot_leader.write("P_Coefficient", 4)
        self.robot_leader.write("I_Coefficient", 0)
        self.robot_leader.write("D_Coefficient", 4)

        self._init_keyboard_listener()

    def _init_keyboard_events(self):
        """
        Initialize the keyboard events dictionary.

        This method sets up tracking for keyboard events used for intervention control.
        It should be overridden in subclasses to add additional events.
        """
        self.keyboard_events = {
            "episode_success": False,
            "episode_end": False,
            "rerecord_episode": False,
        }

    def _handle_key_press(self, key, keyboard):
        """
        Handle key press events.

        Args:
            key: The key that was pressed.
            keyboard: The keyboard module with key definitions.

        This method should be overridden in subclasses for additional key handling.
        """
        try:
            if key == keyboard.Key.esc:
                self.keyboard_events["episode_end"] = True
                return
            if key == keyboard.Key.left:
                self.keyboard_events["rerecord_episode"] = True
                return
            if hasattr(key, "char") and key.char == "s":
                logging.info("Key 's' pressed. Episode success triggered.")
                self.keyboard_events["episode_success"] = True
                return
        except Exception as e:
            logging.error(f"Error handling key press: {e}")

    def _init_keyboard_listener(self):
        """
        Initialize the keyboard listener for intervention control.

        This method sets up keyboard event handling if not in headless mode.
        """
        if is_headless():
            logging.warning(
                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
            )
            return
        try:
            from pynput import keyboard

            def on_press(key):
                with self.event_lock:
                    self._handle_key_press(key, keyboard)

            self.listener = keyboard.Listener(on_press=on_press)
            self.listener.start()

        except ImportError:
            logging.warning("Could not import pynput. Keyboard interface will not be available.")
            self.listener = None

    def _check_intervention(self):
        """
        Check if human intervention is needed.

        Returns:
            Boolean indicating whether intervention is needed.

        This method should be overridden in subclasses with specific intervention logic.
        """
        return False

    def _handle_intervention(self, action):
        """
        Process actions during intervention mode.

        Args:
            action: The original action from the agent.

        Returns:
            Tuple of (modified_action, intervention_action).
        """
        if self.leader_torque_enabled:
            self.robot_leader.write("Torque_Enable", 0)
            self.leader_torque_enabled = False

        leader_pos = self.robot_leader.read("Present_Position")
        follower_pos = self.robot_follower.read("Present_Position")

        # [:3, 3]: the last column of the transformation matrix is the xyz translation
        leader_ee = self.kinematics.fk_gripper_tip(leader_pos)[:3, 3]
        follower_ee = self.kinematics.fk_gripper_tip(follower_pos)[:3, 3]

        if self.prev_leader_ee is None:
            self.prev_leader_ee = leader_ee

        # NOTE: Using the leader's position delta for teleoperation is too noisy.
        # Instead, we move the follower to match the leader's absolute position,
        # and record the leader's position changes as the intervention action.
        action = leader_ee - follower_ee
        action_intervention = leader_ee - self.prev_leader_ee
        self.prev_leader_ee = leader_ee

        if self.use_gripper:
            # Get gripper action delta based on leader pose
            leader_gripper = leader_pos[-1]
            follower_gripper = follower_pos[-1]
            gripper_delta = leader_gripper - follower_gripper

            # Normalize by max angle and quantize to {0, 1, 2}
            normalized_delta = gripper_delta / MAX_GRIPPER_COMMAND
            if normalized_delta > 0.3:
                gripper_action = 2
            elif normalized_delta < -0.3:
                gripper_action = 0
            else:
                gripper_action = 1

            action = np.append(action, gripper_action)
            action_intervention = np.append(action_intervention, gripper_delta)

        return action, action_intervention

    def _handle_leader_teleoperation(self):
        """
        Handle leader teleoperation in non-intervention mode.

        This method synchronizes the leader robot position with the follower.
        """
        if not self.leader_torque_enabled:
            self.robot_leader.write("Torque_Enable", 1)
            self.leader_torque_enabled = True

        follower_pos = self.robot_follower.read("Present_Position")
        self.robot_leader.write("Goal_Position", follower_pos)

    def step(self, action):
        """
        Execute a step with possible human intervention.

        Args:
            action: The action to take in the environment.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info).
        """
        is_intervention = self._check_intervention()
        action_intervention = None

        if is_intervention:
            action, action_intervention = self._handle_intervention(action)
        else:
            self._handle_leader_teleoperation()

        obs, reward, terminated, truncated, info = self.env.step(action)

        # Add intervention info
        info["is_intervention"] = is_intervention
        info["action_intervention"] = action_intervention if is_intervention else None

        # Check for success or manual termination
        success = self.keyboard_events["episode_success"]
        terminated = terminated or self.keyboard_events["episode_end"] or success

        if success:
            reward = 1.0
            logging.info("Episode ended successfully with reward 1.0")

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the environment and intervention state.

        Args:
            **kwargs: Keyword arguments passed to the wrapped environment's reset.

        Returns:
            The initial observation and info.
        """
        self.prev_leader_ee = None
        self.prev_leader_pos = None
        self.keyboard_events = dict.fromkeys(self.keyboard_events, False)
        return super().reset(**kwargs)

    def close(self):
        """
        Clean up resources, including stopping the keyboard listener.

        Returns:
            Result of closing the wrapped environment.
        """
        if hasattr(self, "listener") and self.listener is not None:
            self.listener.stop()
        return self.env.close()


class GearedLeaderControlWrapper(BaseLeaderControlWrapper):
    """
    Wrapper that enables manual intervention via keyboard.

    This wrapper extends the BaseLeaderControlWrapper to allow explicit toggling
    of human intervention mode with keyboard controls.
    """

    def _init_keyboard_events(self):
        """
        Initialize keyboard events including the human intervention flag.

        Extends the base class dictionary with an additional flag for tracking
        intervention state toggled by keyboard.
        """
        super()._init_keyboard_events()
        self.keyboard_events["human_intervention_step"] = False

    def _handle_key_press(self, key, keyboard):
        """
        Handle key presses including space for the intervention toggle.

        Args:
            key: The key that was pressed.
            keyboard: The keyboard module with key definitions.

        Extends the base handler to respond to the space key for toggling intervention.
        """
        super()._handle_key_press(key, keyboard)
        if key == keyboard.Key.space:
            if not self.keyboard_events["human_intervention_step"]:
                logging.info(
                    "Space key pressed. Human intervention required.\n"
                    "Place the leader in similar pose to the follower and press space again."
                )
                self.keyboard_events["human_intervention_step"] = True
                log_say("Human intervention step.", play_sounds=True)
            else:
                self.keyboard_events["human_intervention_step"] = False
                logging.info("Space key pressed for a second time.\nContinuing with policy actions.")
                log_say("Continuing with policy actions.", play_sounds=True)

    def _check_intervention(self):
        """
        Check if human intervention is active based on the keyboard toggle.

        Returns:
            Boolean indicating whether intervention mode is active.
        """
        return self.keyboard_events["human_intervention_step"]


class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper):
    """
    Wrapper with automatic intervention based on error thresholds.

    This wrapper monitors the error between leader and follower positions
    and automatically triggers intervention when the error exceeds thresholds.
    """

    def __init__(
        self,
        env,
        ee_action_space_params=None,
        use_gripper=False,
        intervention_threshold=1.7,
        release_threshold=0.01,
        queue_size=10,
    ):
        """
        Initialize the automatic intervention wrapper.

        Args:
            env: The environment to wrap.
            ee_action_space_params: Parameters defining the end-effector action space.
            use_gripper: Whether to include gripper control.
            intervention_threshold: Error threshold to trigger intervention.
            release_threshold: Error threshold to release intervention.
            queue_size: Number of error measurements to track for smoothing.
        """
        super().__init__(env, ee_action_space_params=ee_action_space_params, use_gripper=use_gripper)

        # Error tracking parameters
        self.intervention_threshold = intervention_threshold  # Threshold to trigger intervention
        self.release_threshold = release_threshold  # Threshold to release intervention
        self.queue_size = queue_size  # Number of error measurements to keep

        # Error tracking variables
        self.error_queue = deque(maxlen=self.queue_size)
        self.error_over_time_queue = deque(maxlen=self.queue_size)
        self.previous_error = 0.0
        self.is_intervention_active = False
        self.start_time = time.perf_counter()

    def _check_intervention(self):
        """
        Determine if intervention should occur based on leader-follower error.

        This method monitors the error rate between leader and follower positions
        and automatically triggers intervention when the error rate exceeds
        the intervention threshold, releasing when it falls below the release threshold.

        Returns:
            Boolean indicating whether intervention should be active.
        """
        # Skip intervention logic for the first few steps to collect data
        if time.perf_counter() - self.start_time < 1.0:  # Wait 1 second before enabling
            return False

        # Get current positions
        leader_positions = self.robot_leader.read("Present_Position")
        follower_positions = self.robot_follower.read("Present_Position")

        # Calculate error and error rate
        error = np.linalg.norm(leader_positions - follower_positions)
        error_over_time = np.abs(error - self.previous_error)

        # Add to queue for running average
        self.error_queue.append(error)
        self.error_over_time_queue.append(error_over_time)

        # Update previous error
        self.previous_error = error

        # Calculate averages if we have enough data
        if len(self.error_over_time_queue) >= self.queue_size:
            avg_error_over_time = np.mean(self.error_over_time_queue)

            # Debug info
            if self.is_intervention_active:
                logging.debug(f"Error rate during intervention: {avg_error_over_time:.4f}")

            # Determine if intervention should start or stop
            if not self.is_intervention_active and avg_error_over_time > self.intervention_threshold:
                # Transition to intervention mode
                self.is_intervention_active = True
                logging.info(f"Starting automatic intervention: error rate {avg_error_over_time:.4f}")

            elif self.is_intervention_active and avg_error_over_time < self.release_threshold:
                # End intervention mode
                self.is_intervention_active = False
                logging.info(f"Ending automatic intervention: error rate {avg_error_over_time:.4f}")

        return self.is_intervention_active

    def reset(self, **kwargs):
        """
        Reset error tracking on environment reset.

        Args:
            **kwargs: Keyword arguments passed to the wrapped environment's reset.

        Returns:
            The initial observation and info.
        """
        self.error_queue.clear()
        self.error_over_time_queue.clear()
        self.previous_error = 0.0
        self.is_intervention_active = False
        self.start_time = time.perf_counter()
        return super().reset(**kwargs)


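# Hysteresis illustration for the trigger above (illustrative numbers): with
# intervention_threshold=1.7 and release_threshold=0.01, a human grabbing the
# leader makes the smoothed per-step error delta jump above 1.7 (intervention
# starts); once the human holds still, the delta decays below 0.01 and control
# returns to the policy. Using two different thresholds keeps the mode from
# oscillating near a single boundary.

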
class GamepadControlWrapper(gym.Wrapper):
    """
    Wrapper that allows controlling a gym environment with a gamepad.

    This wrapper intercepts the step method and allows human input via gamepad
    to override the agent's actions when desired.
    """

    def __init__(
        self,
        env,
        x_step_size=1.0,
        y_step_size=1.0,
        z_step_size=1.0,
        use_gripper=False,
        auto_reset=False,
        input_threshold=0.001,
    ):
        """
        Initialize the gamepad controller wrapper.

        Args:
            env: The environment to wrap.
            x_step_size: Base movement step size for X axis in meters.
            y_step_size: Base movement step size for Y axis in meters.
            z_step_size: Base movement step size for Z axis in meters.
            use_gripper: Whether to include gripper control.
            auto_reset: Whether to auto reset the environment when an episode ends.
            input_threshold: Minimum movement delta to consider as active input.
        """
        super().__init__(env)
        from lerobot.scripts.server.end_effector_control_utils import (
            GamepadController,
            GamepadControllerHID,
        )

        # Use the HID-based controller on macOS
        if sys.platform == "darwin":
            self.controller = GamepadControllerHID(
                x_step_size=x_step_size,
                y_step_size=y_step_size,
                z_step_size=z_step_size,
            )
        else:
            self.controller = GamepadController(
                x_step_size=x_step_size,
                y_step_size=y_step_size,
                z_step_size=z_step_size,
            )
        self.auto_reset = auto_reset
        self.use_gripper = use_gripper
        self.input_threshold = input_threshold
        self.controller.start()

        logging.info("Gamepad control wrapper initialized")
        print("Gamepad controls:")
        print("  Left analog stick: Move in X-Y plane")
        print("  Right analog stick: Move in Z axis (up/down)")
        print("  X/Square button: End episode (FAILURE)")
        print("  Y/Triangle button: End episode (SUCCESS)")
        print("  B/Circle button: Exit program")

    def get_gamepad_action(
        self,
    ) -> Tuple[bool, np.ndarray, bool, bool, bool]:
        """
        Get the current action from the gamepad if any input is active.

        Returns:
            Tuple containing:
                - is_active: Whether gamepad input is active
                - action: The action derived from gamepad input
                - terminate_episode: Whether episode termination was requested
                - success: Whether episode success was signaled
                - rerecord_episode: Whether episode rerecording was requested
        """
        # Update the controller to get fresh inputs
        self.controller.update()

        # Get movement deltas from the controller
        delta_x, delta_y, delta_z = self.controller.get_deltas()

        intervention_is_active = self.controller.should_intervene()

        # Create action from gamepad input
        gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)

        if self.use_gripper:
            gripper_command = self.controller.gripper_command()
            if gripper_command == "open":
                gamepad_action = np.concatenate([gamepad_action, [2.0]])
            elif gripper_command == "close":
                gamepad_action = np.concatenate([gamepad_action, [0.0]])
            else:
                gamepad_action = np.concatenate([gamepad_action, [1.0]])

        # Check episode ending buttons.
        # We rely on controller.get_episode_end_status(), which returns "success", "failure", or None.
        episode_end_status = self.controller.get_episode_end_status()
        terminate_episode = episode_end_status is not None
        success = episode_end_status == "success"
        rerecord_episode = episode_end_status == "rerecord_episode"

        return (
            intervention_is_active,
            gamepad_action,
            terminate_episode,
            success,
            rerecord_episode,
        )

    def step(self, action):
        """
        Step the environment, using gamepad input to override actions when active.

        Args:
            action: Original action from the agent.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info).
        """
        # Get gamepad state and action
        (
            is_intervention,
            gamepad_action,
            terminate_episode,
            success,
            rerecord_episode,
        ) = self.get_gamepad_action()

        # Update episode ending state if requested
        if terminate_episode:
            logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}")

        # Only override the action if gamepad input is active
        action = gamepad_action if is_intervention else action

        # Step the environment
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Add episode ending if requested via gamepad
        terminated = terminated or truncated or terminate_episode

        if success:
            reward = 1.0
            logging.info("Episode ended successfully with reward 1.0")

        if isinstance(action, np.ndarray):
            action = torch.from_numpy(action)

        info["is_intervention"] = is_intervention
        info["action_intervention"] = action
        info["rerecord_episode"] = rerecord_episode

        # If the episode ended, reset the state
        if terminated or truncated:
            # Add success/failure information to the info dict
            info["next.success"] = success

            # Auto reset if configured
            if self.auto_reset:
                obs, reset_info = self.reset()
                info.update(reset_info)

        return obs, reward, terminated, truncated, info

    def close(self):
        """
        Clean up resources when the environment closes.

        Returns:
            Result of closing the wrapped environment.
        """
        # Stop the controller
        if hasattr(self, "controller"):
            self.controller.stop()

        # Call the parent close method
        return self.env.close()


class GymHilDeviceWrapper(gym.Wrapper):
    """
    Wrapper that moves gym_hil observations (and any intervention action in info)
    to a target device.
    """

    def __init__(self, env, device="cpu"):
        super().__init__(env)
        self.device = device

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        for k in obs:
            obs[k] = obs[k].to(self.device)
        if "action_intervention" in info:
            info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
        return obs, reward, terminated, truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        obs, info = self.env.reset(seed=seed, options=options)
        for k in obs:
            obs[k] = obs[k].to(self.device)
        if "action_intervention" in info:
            info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
        return obs, info


class GymHilObservationProcessorWrapper(gym.ObservationWrapper):
    """Remap gym_hil observations to the LeRobot convention.

    Camera images under "pixels" become "observation.images.<camera>" and
    "agent_pos" becomes "observation.state".
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        prev_space = self.observation_space
        new_space = {}

        for key in prev_space:
            if "pixels" in key:
                for k in prev_space["pixels"]:
                    new_space[f"observation.images.{k}"] = gym.spaces.Box(
                        0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8
                    )

            if key == "agent_pos":
                new_space["observation.state"] = prev_space["agent_pos"]

        self.observation_space = gym.spaces.Dict(new_space)

    def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
        return preprocess_observation(observation)

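# Illustrative shape of the remapping performed by GymHilObservationProcessorWrapper,
# assuming a single "front" camera (left: gym_hil keys, right: LeRobot keys):
#
#   {"pixels": {"front": img}, "agent_pos": state}
#       -> {"observation.images.front": ..., "observation.state": ...}
#
# The actual tensor conversion is delegated to preprocess_observation.
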
###########################################################
# Factory functions
###########################################################


def make_robot_env(cfg) -> gym.Env:
    """
    Factory function to create a robot environment.

    This function builds a robot environment with all necessary wrappers
    based on the provided configuration.

    Args:
        cfg: Configuration object containing environment parameters.

    Returns:
        A gym environment with all necessary wrappers applied. The
        BatchCompatibleWrapper adds a leading batch dimension so the
        environment can be consumed like a single-env batch.
    """
    if cfg.type == "hil":
        # TODO (azouitine)
        env = gym.make(
            f"gym_hil/{cfg.task}",
            image_obs=True,
            render_mode="human",
            step_size=cfg.wrapper.ee_action_space_params.x_step_size,
            use_gripper=cfg.wrapper.use_gripper,
            gripper_penalty=cfg.wrapper.gripper_penalty,
        )
        env = GymHilObservationProcessorWrapper(env=env)
        env = GymHilDeviceWrapper(env=env, device=cfg.device)
        env = BatchCompatibleWrapper(env=env)
        env = TorchActionWrapper(env=env, device=cfg.device)
        return env

    robot = make_robot_from_config(cfg.robot)

    # Create base environment
    env = RobotEnv(
        robot=robot,
        display_cameras=cfg.wrapper.display_cameras,
    )

    # Add observation and image processing
    if cfg.wrapper.add_joint_velocity_to_observation:
        env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
    if cfg.wrapper.add_current_to_observation:
        env = AddCurrentToObservation(env=env)
    if cfg.wrapper.add_ee_pose_to_observation:
        env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)

    env = ConvertToLeRobotObservation(env=env, device=cfg.device)

    if cfg.wrapper.crop_params_dict is not None:
        env = ImageCropResizeWrapper(
            env=env,
            crop_params_dict=cfg.wrapper.crop_params_dict,
            resize_size=cfg.wrapper.resize_size,
        )

    # Add reward computation and control wrappers
    reward_classifier = init_reward_classifier(cfg)
    if reward_classifier is not None:
        env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)

    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
    if cfg.wrapper.use_gripper:
        env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
        if cfg.wrapper.gripper_penalty is not None:
            env = GripperPenaltyWrapper(
                env=env,
                penalty=cfg.wrapper.gripper_penalty,
            )

    env = EEActionWrapper(
        env=env,
        ee_action_space_params=cfg.wrapper.ee_action_space_params,
        use_gripper=cfg.wrapper.use_gripper,
    )

    if cfg.wrapper.ee_action_space_params.control_mode == "gamepad":
        env = GamepadControlWrapper(
            env=env,
            x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
            y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
            z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
            use_gripper=cfg.wrapper.use_gripper,
        )
    elif cfg.wrapper.ee_action_space_params.control_mode == "leader":
        env = GearedLeaderControlWrapper(
            env=env,
            ee_action_space_params=cfg.wrapper.ee_action_space_params,
            use_gripper=cfg.wrapper.use_gripper,
        )
    elif cfg.wrapper.ee_action_space_params.control_mode == "leader_automatic":
        env = GearedLeaderAutomaticControlWrapper(
            env=env,
            ee_action_space_params=cfg.wrapper.ee_action_space_params,
            use_gripper=cfg.wrapper.use_gripper,
        )
    else:
        raise ValueError(f"Invalid control mode: {cfg.wrapper.ee_action_space_params.control_mode}")

    env = ResetWrapper(
        env=env,
        reset_pose=cfg.wrapper.fixed_reset_joint_positions,
        reset_time_s=cfg.wrapper.reset_time_s,
    )
    env = BatchCompatibleWrapper(env=env)
    env = TorchActionWrapper(env=env, device=cfg.device)

    return env

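# A minimal usage sketch (the cfg fields referenced above are defined by EnvConfig;
# the config source is illustrative):
#
#   cfg = ...  # an EnvConfig, e.g. parsed from the CLI via @parser.wrap()
#   env = make_robot_env(cfg)
#   obs, info = env.reset()
#   obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#   env.close()
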
def init_reward_classifier(cfg):
    """
    Load a reward classifier policy from a pretrained path if configured.

    Args:
        cfg: The environment configuration containing classifier paths.

    Returns:
        The loaded classifier model, or None if not configured.
    """
    if cfg.reward_classifier_pretrained_path is None:
        return None

    from lerobot.common.policies.reward_model.modeling_classifier import Classifier

    # Get the device from the config, defaulting to CPU
    device = getattr(cfg, "device", "cpu")

    # Load the classifier directly using from_pretrained
    classifier = Classifier.from_pretrained(
        pretrained_name_or_path=cfg.reward_classifier_pretrained_path,
    )

    # Ensure the model is on the correct device
    classifier.to(device)
    classifier.eval()  # Set to evaluation mode

    return classifier

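# Illustrative standalone use (assumes cfg.reward_classifier_pretrained_path points
# at a trained lerobot reward classifier checkpoint; the method name is the one
# consumed by RewardWrapper and may differ across versions):
#
#   classifier = init_reward_classifier(cfg)
#   if classifier is not None:
#       with torch.no_grad():
#           reward = classifier.predict_reward(obs)
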
###########################################################
# Record and replay functions
###########################################################


def record_dataset(env, policy, cfg, success_collection_steps=15):
    """
    Record a dataset of robot interactions using either a policy or teleop.

    This function runs episodes in the environment and records the observations,
    actions, and results for dataset creation.

    Args:
        env: The environment to record from.
        policy: Optional policy to generate actions (if None, uses teleop).
        cfg: Configuration object containing recording parameters like:
            - repo_id: Repository ID for dataset storage
            - dataset_root: Local root directory for dataset
            - num_episodes: Number of episodes to record
            - fps: Frames per second for recording
            - push_to_hub: Whether to push dataset to Hugging Face Hub
            - task: Name/description of the task being recorded
        success_collection_steps: Number of additional steps to continue recording
            after a success (reward=1) is detected. This helps collect more
            positive examples for reward classifier training.
    """
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Setup initial action (zero action if using teleop)
    action = env.action_space.sample() * 0.0

    # Configure dataset features based on environment spaces
    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": env.observation_space["observation.state"].shape,
            "names": None,
        },
        "action": {
            "dtype": "float32",
            "shape": env.action_space.shape,
            "names": None,
        },
        "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
        "next.done": {"dtype": "bool", "shape": (1,), "names": None},
        "complementary_info.discrete_penalty": {
            "dtype": "float32",
            "shape": (1,),
            "names": ["discrete_penalty"],
        },
    }

    # Add image features
    for key in env.observation_space:
        if "image" in key:
            features[key] = {
                "dtype": "video",
                "shape": env.observation_space[key].shape,
                "names": None,
            }

    # Create dataset
    dataset = LeRobotDataset.create(
        cfg.repo_id,
        cfg.fps,
        root=cfg.dataset_root,
        use_videos=True,
        image_writer_threads=4,
        image_writer_processes=0,
        features=features,
    )

    # Record episodes
    episode_index = 0
    recorded_action = None
    while episode_index < cfg.num_episodes:
        obs, _ = env.reset()
        start_episode_t = time.perf_counter()
        log_say(f"Recording episode {episode_index}", play_sounds=True)

        # Track success state collection
        success_detected = False
        success_steps_collected = 0

        # Run episode steps
        while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s:
            start_loop_t = time.perf_counter()

            # Get action from policy if available
            if cfg.pretrained_policy_name_or_path is not None:
                action = policy.select_action(obs)

            # Step environment
            obs, reward, terminated, truncated, info = env.step(action)

            # Check if the episode needs to be re-recorded
            if info.get("rerecord_episode", False):
                break

            # For teleop, record the intervention action; otherwise the policy action
            recorded_action = {
                "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action
            }

            # Process observation for dataset
            obs_processed = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}

            # Check if we've just detected success
            if reward == 1.0 and not success_detected:
                success_detected = True
                logging.info("Success detected! Collecting additional success states.")

            # Add frame to dataset - continue marking as success even during extra collection steps
            frame = {**obs_processed, **recorded_action}

            # While in the success collection phase, keep marking rewards as 1.0
            if success_detected:
                frame["next.reward"] = np.array([1.0], dtype=np.float32)
            else:
                frame["next.reward"] = np.array([reward], dtype=np.float32)

            # Only mark as done if we're truly done (reached the end or collected enough success states)
            really_done = terminated or truncated
            if success_detected:
                success_steps_collected += 1
                really_done = success_steps_collected >= success_collection_steps

            frame["next.done"] = np.array([really_done], dtype=bool)
            frame["task"] = cfg.task
            frame["complementary_info.discrete_penalty"] = torch.tensor(
                [info.get("discrete_penalty", 0.0)], dtype=torch.float32
            )
            dataset.add_frame(frame)

            # Maintain consistent timing
            if cfg.fps:
                dt_s = time.perf_counter() - start_loop_t
                busy_wait(1 / cfg.fps - dt_s)

            # Check if we should end the episode
            if (terminated or truncated) and not success_detected:
                # Regular termination without success
                break
            elif success_detected and success_steps_collected >= success_collection_steps:
                # We've collected enough success states
                logging.info(f"Collected {success_steps_collected} additional success states")
                break

        # Handle episode re-recording
        if info.get("rerecord_episode", False):
            dataset.clear_episode_buffer()
            logging.info(f"Re-recording episode {episode_index}")
            continue

        dataset.save_episode(cfg.task)
        episode_index += 1

    # Finalize dataset
    # dataset.consolidate(run_compute_stats=True)
    if cfg.push_to_hub:
        dataset.push_to_hub()

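# Illustrative read-back of a recorded frame (keys follow the `features` schema
# built above; the exact LeRobotDataset indexing API may differ across versions):
#
#   ds = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root)
#   sample = ds[0]
#   sample["action"], sample["next.reward"], sample["next.done"]
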
def replay_episode(env, cfg):
    """
    Replay a recorded episode in the environment.

    This function loads actions from a previously recorded episode
    and executes them in the environment.

    Args:
        env: The environment to replay in.
        cfg: Configuration object containing replay parameters:
            - repo_id: Repository ID for dataset
            - dataset_root: Local root directory for dataset
            - episode: Episode ID to replay
    """
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
    env.reset()

    actions = dataset.hf_dataset.select_columns("action")

    for idx in range(dataset.num_frames):
        start_loop_t = time.perf_counter()

        action = actions[idx]["action"]
        env.step(action)

        # Replay at a fixed 10 FPS
        dt_s = time.perf_counter() - start_loop_t
        busy_wait(1 / 10 - dt_s)


@parser.wrap()
def main(cfg: EnvConfig):
    """
    Main entry point for the robot environment script.

    This function runs the robot environment in one of several modes
    based on the provided configuration.

    Args:
        cfg: Configuration object defining the run parameters,
            including mode (record, replay, random) and other settings.
    """
    env = make_robot_env(cfg)

    if cfg.mode == "record":
        policy = None
        if cfg.pretrained_policy_name_or_path is not None:
            from lerobot.common.policies.sac.modeling_sac import SACPolicy

            policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
            policy.to(cfg.device)
            policy.eval()

        # Record with the default of 15 extra success-collection steps
        record_dataset(
            env,
            policy=policy,
            cfg=cfg,
            success_collection_steps=15,
        )
        exit()

    if cfg.mode == "replay":
        replay_episode(
            env,
            cfg=cfg,
        )
        exit()

    env.reset()

    # Initialize the smoothed action as a random sample.
    smoothed_action = env.action_space.sample()

    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
    # A value close to 0 makes the trajectory very smooth (slow to change), while a value
    # close to 1 makes it less smooth; alpha = 1.0 disables smoothing entirely.
    alpha = 1.0

    num_episode = 0
    successes = []
    while num_episode < 10:
        start_loop_s = time.perf_counter()
        # Sample a new random action from the robot's action space.
        new_random_action = env.action_space.sample()
        # Update the smoothed action using an exponential moving average.
        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action

        # Execute the step; the action is already a torch tensor thanks to TorchActionWrapper.
        obs, reward, terminated, truncated, info = env.step(smoothed_action)
        if terminated or truncated:
            successes.append(reward)
            env.reset()
            num_episode += 1

        dt_s = time.perf_counter() - start_loop_s
        busy_wait(1 / cfg.fps - dt_s)

    logging.info(f"Rewards at episode ends over {num_episode} episodes: {successes}")
    logging.info(f"Success rate: {sum(successes) / len(successes)}")

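# Hypothetical invocations (the parser is lerobot's draccus-based CLI parser; flag
# names follow the EnvConfig fields used above, and paths are placeholders):
#
#   python <this_script>.py --config_path=path/to/env_config.json --mode=record
#   python <this_script>.py --config_path=path/to/env_config.json --mode=replay --episode=0
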
if __name__ == "__main__":
    main()