General fixes to abide by the new config in learner_server, actor_server, gym_manipulator

This commit is contained in:
Michel Aractingi
2025-05-27 15:49:33 +02:00
committed by AdilZouitine
parent df96e5b3b2
commit 1edfbf792a
8 changed files with 49 additions and 44 deletions

View File

@@ -57,7 +57,7 @@ class AlohaEnv(EnvConfig):
features_map: dict[str, str] = field( features_map: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"action": ACTION, "action": ACTION,
"agent_pos": OBS_ROBOT, "agent_pos": OBS_STATE,
"top": f"{OBS_IMAGE}.top", "top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top", "pixels/top": f"{OBS_IMAGES}.top",
} }

View File

@@ -1,3 +0,0 @@
from lerobot.common.model.kinematics_utils import RobotKinematics
__all__ = ["RobotKinematics"]

View File

@@ -1,4 +1,3 @@
from .config import RobotConfig from .config import RobotConfig
from .robot import Robot from .robot import Robot
from .robot_wrapper import RobotWrapper
from .utils import make_robot_from_config from .utils import make_robot_from_config

View File

@@ -66,7 +66,7 @@ class SO100FollowerEndEffector(SO100Follower):
self.kinematics = RobotKinematics(robot_type="so101") self.kinematics = RobotKinematics(robot_type="so101")
# Set the forward kinematics function # Set the forward kinematics function
self.fk_function = self.kinematics.fk_gripper_tip self.fk_function = self.kinematics.fk_gripper
# Store the bounds for end-effector position # Store the bounds for end-effector position
self.end_effector_bounds = self.config.end_effector_bounds self.end_effector_bounds = self.config.end_effector_bounds
@@ -152,16 +152,16 @@ class SO100FollowerEndEffector(SO100Follower):
fk_func=self.fk_function, fk_func=self.fk_function,
) )
target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0)
# Create joint space action dictionary # Create joint space action dictionary
joint_action = { joint_action = {
f"{key}.pos": target_joint_values_in_degrees[i] f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
for i, key in enumerate(self.bus.motors.keys())
} }
# Handle gripper separately if included in action # Handle gripper separately if included in action
joint_action["gripper.pos"] = np.clip( joint_action["gripper.pos"] = np.clip(
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos, self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
0, 5,
self.config.max_gripper_pos, self.config.max_gripper_pos,
) )
@@ -191,3 +191,7 @@ class SO100FollowerEndEffector(SO100Follower):
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict return obs_dict
def reset(self):
self.current_ee_pos = None
self.current_joint_pos = None

View File

@@ -16,7 +16,6 @@
import logging import logging
import time import time
from typing import Any
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
@@ -28,7 +27,6 @@ from lerobot.common.motors.feetech import (
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
from .config_so101_leader import SO101LeaderConfig from .config_so101_leader import SO101LeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -141,4 +139,4 @@ class SO101Leader(Teleoperator):
DeviceNotConnectedError(f"{self} is not connected.") DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect() self.bus.disconnect()
logger.info(f"{self} disconnected.") logger.info(f"{self} disconnected.")

View File

@@ -80,10 +80,13 @@ import torch
from torch import nn from torch import nn
from torch.multiprocessing import Event, Queue from torch.multiprocessing import Event, Queue
from lerobot.common.cameras import opencv # noqa: F401
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.robots import so100_follower_end_effector # noqa: F401
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
TimerManager, TimerManager,
get_safe_torch_device, get_safe_torch_device,

View File

@@ -25,7 +25,7 @@ import numpy as np
import torch import torch
import torchvision.transforms.functional as F # noqa: N812 import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.cameras import intel, opencv # noqa: F401 from lerobot.common.cameras import opencv # noqa: F401
from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import preprocess_observation from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.model.kinematics import RobotKinematics from lerobot.common.model.kinematics import RobotKinematics
@@ -37,9 +37,7 @@ from lerobot.common.robots import ( # noqa: F401
from lerobot.common.teleoperators import ( from lerobot.common.teleoperators import (
gamepad, # noqa: F401 gamepad, # noqa: F401
make_teleoperator_from_config, make_teleoperator_from_config,
so101_leader,
) )
from lerobot.common.teleoperators.gamepad.configuration_gamepad import GamepadTeleopConfig
from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop
from lerobot.common.utils.robot_utils import busy_wait from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.utils import log_say from lerobot.common.utils.utils import log_say
@@ -307,6 +305,8 @@ class RobotEnv(gym.Env):
""" """
super().reset(seed=seed, options=options) super().reset(seed=seed, options=options)
self.robot.reset()
# Capture the initial observation. # Capture the initial observation.
observation = self._get_observation() observation = self._get_observation()
@@ -1003,7 +1003,9 @@ class GripperActionWrapper(gym.ActionWrapper):
gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"]
gripper_action_value = np.clip(gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos) gripper_action_value = np.clip(
gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos
)
action[-1] = gripper_action_value.item() action[-1] = gripper_action_value.item()
return action return action
@@ -1138,7 +1140,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
) )
# Initialize kinematics instance for the appropriate robot type # Initialize kinematics instance for the appropriate robot type
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100") robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101")
self.kinematics = RobotKinematics(robot_type) self.kinematics = RobotKinematics(robot_type)
self.fk_function = self.kinematics.fk_gripper_tip self.fk_function = self.kinematics.fk_gripper_tip
@@ -1152,16 +1154,10 @@ class EEObservationWrapper(gym.ObservationWrapper):
Returns: Returns:
Enhanced observation with end-effector pose information. Enhanced observation with end-effector pose information.
""" """
current_joint_pos = self.unwrapped._get_observation()["observation.state"] current_joint_pos = self.unwrapped._get_observation()["agent_pos"]
current_ee_pos = self.fk_function(current_joint_pos) current_ee_pos = self.fk_function(current_joint_pos)
observation["observation.state"] = torch.cat( observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos[:3, 3]], -1)
[
observation["observation.state"],
torch.from_numpy(current_ee_pos[:3, 3]),
],
dim=-1,
)
return observation return observation
@@ -1178,9 +1174,7 @@ class BaseLeaderControlWrapper(gym.Wrapper):
where the human can control a leader robot to guide the follower robot's movements. where the human can control a leader robot to guide the follower robot's movements.
""" """
def __init__( def __init__(self, env, teleop_device, use_geared_leader_arm: bool = False, use_gripper=False):
self, env, teleop_device, use_geared_leader_arm: bool = False, use_gripper=False
):
""" """
Initialize the base leader control wrapper. Initialize the base leader control wrapper.
@@ -1322,10 +1316,12 @@ class BaseLeaderControlWrapper(gym.Wrapper):
if self.use_gripper: if self.use_gripper:
if self.prev_leader_gripper is None: if self.prev_leader_gripper is None:
self.prev_leader_gripper = np.clip(leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos) self.prev_leader_gripper = np.clip(
leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos
)
# Get gripper action delta based on leader pose # Get gripper action delta based on leader pose
leader_gripper = leader_pos[-1] leader_gripper = leader_pos[-1]
# follower_gripper = follower_pos[-1] # follower_gripper = follower_pos[-1]
gripper_delta = leader_gripper - self.prev_leader_gripper gripper_delta = leader_gripper - self.prev_leader_gripper
@@ -1342,7 +1338,7 @@ class BaseLeaderControlWrapper(gym.Wrapper):
action = np.append(action, gripper_action) action = np.append(action, gripper_action)
# action_intervention = np.append(action_intervention, gripper_delta) # action_intervention = np.append(action_intervention, gripper_delta)
return action # , action_intervention return action # , action_intervention
def _handle_leader_teleoperation(self): def _handle_leader_teleoperation(self):
""" """
@@ -1389,7 +1385,11 @@ class BaseLeaderControlWrapper(gym.Wrapper):
info["is_intervention"] = is_intervention info["is_intervention"] = is_intervention
info["action_intervention"] = action if is_intervention else None info["action_intervention"] = action if is_intervention else None
self.prev_leader_gripper = np.clip(self.robot_leader.bus.sync_read("Present_Position")["gripper"], 0, self.robot_follower.config.max_gripper_pos) self.prev_leader_gripper = np.clip(
self.robot_leader.bus.sync_read("Present_Position")["gripper"],
0,
self.robot_follower.config.max_gripper_pos,
)
# Check for success or manual termination # Check for success or manual termination
success = self.keyboard_events["episode_success"] success = self.keyboard_events["episode_success"]
@@ -1569,7 +1569,7 @@ class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper):
self.ee_error_over_time_queue[-1] - self.previous_ee_error_over_time_over_time self.ee_error_over_time_queue[-1] - self.previous_ee_error_over_time_over_time
) )
self.previous_ee_error_over_time_over_time = self.ee_error_over_time_queue[-1] self.previous_ee_error_over_time_over_time = self.ee_error_over_time_queue[-1]
self.intervention_threshold = 0.02 self.intervention_threshold = 0.02
# Determine if intervention should start or stop based on the thresholds set in the constructor # Determine if intervention should start or stop based on the thresholds set in the constructor
@@ -1890,8 +1890,8 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps) env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.wrapper.add_current_to_observation: if cfg.wrapper.add_current_to_observation:
env = AddCurrentToObservation(env=env) env = AddCurrentToObservation(env=env)
if False and cfg.wrapper.add_ee_pose_to_observation: if cfg.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds) env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds)
env = ConvertToLeRobotObservation(env=env, device=cfg.device) env = ConvertToLeRobotObservation(env=env, device=cfg.device)
@@ -1910,9 +1910,9 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
if cfg.wrapper: if cfg.wrapper:
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper: if cfg.wrapper.use_gripper:
env = GripperActionWrapper( # env = GripperActionWrapper(
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold # env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
) # )
if cfg.wrapper.gripper_penalty is not None: if cfg.wrapper.gripper_penalty is not None:
env = GripperPenaltyWrapper( env = GripperPenaltyWrapper(
env=env, env=env,
@@ -1922,7 +1922,9 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
# Control mode specific wrappers # Control mode specific wrappers
control_mode = cfg.wrapper.control_mode control_mode = cfg.wrapper.control_mode
if control_mode == "gamepad": if control_mode == "gamepad":
assert isinstance(teleop_device, GamepadTeleop), "teleop_device must be an instance of GamepadTeleop for gamepad control mode" assert isinstance(teleop_device, GamepadTeleop), (
"teleop_device must be an instance of GamepadTeleop for gamepad control mode"
)
env = GamepadControlWrapper( env = GamepadControlWrapper(
env=env, env=env,
teleop_device=teleop_device, teleop_device=teleop_device,
@@ -2117,11 +2119,10 @@ def record_dataset(env, policy, cfg):
really_done = success_steps_collected >= cfg.number_of_steps_after_success really_done = success_steps_collected >= cfg.number_of_steps_after_success
frame["next.done"] = np.array([really_done], dtype=bool) frame["next.done"] = np.array([really_done], dtype=bool)
frame["task"] = cfg.task
frame["complementary_info.discrete_penalty"] = torch.tensor( frame["complementary_info.discrete_penalty"] = torch.tensor(
[info.get("discrete_penalty", 0.0)], dtype=torch.float32 [info.get("discrete_penalty", 0.0)], dtype=torch.float32
) )
dataset.add_frame(frame) dataset.add_frame(frame, task=cfg.task)
# Maintain consistent timing # Maintain consistent timing
if cfg.fps: if cfg.fps:
@@ -2233,7 +2234,7 @@ def main(cfg: EnvConfig):
while num_episode < 10: while num_episode < 10:
start_loop_s = time.perf_counter() start_loop_s = time.perf_counter()
# Sample a new random action from the robot's action space. # Sample a new random action from the robot's action space.
new_random_action = env.action_space.sample() new_random_action = env.action_space.sample() * 0.0
# Update the smoothed action using an exponential moving average. # Update the smoothed action using an exponential moving average.
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action

View File

@@ -91,6 +91,7 @@ from torch import nn
from torch.multiprocessing import Queue from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from lerobot.common.cameras import so100_follower_end_effector # noqa: F401
from lerobot.common.constants import ( from lerobot.common.constants import (
CHECKPOINTS_DIR, CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK, LAST_CHECKPOINT_LINK,
@@ -101,6 +102,8 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import ( from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir, get_step_checkpoint_dir,