General fixes to abide by the new config in learner_server, actor_server, gym_manipulator

Michel Aractingi
2025-05-27 15:49:33 +02:00
committed by AdilZouitine
parent df96e5b3b2
commit 1edfbf792a
8 changed files with 49 additions and 44 deletions

View File

@@ -57,7 +57,7 @@ class AlohaEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"agent_pos": OBS_STATE,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
}
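
Note: the observation key constant changes from OBS_ROBOT to OBS_STATE so AlohaEnv's features_map matches the new config naming. A minimal sketch of the resulting mapping, assuming the constants resolve to the usual lerobot string keys (the exact values are an assumption, not confirmed by this diff):

    # Hedged sketch: assumed string values behind the constants used in features_map.
    ACTION = "action"
    OBS_STATE = "observation.state"    # assumed replacement for the older OBS_ROBOT
    OBS_IMAGE = "observation.image"
    OBS_IMAGES = "observation.images"

    features_map = {
        "action": ACTION,
        "agent_pos": OBS_STATE,
        "top": f"{OBS_IMAGE}.top",
        "pixels/top": f"{OBS_IMAGES}.top",
    }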

View File

@@ -1,3 +0,0 @@
- from lerobot.common.model.kinematics_utils import RobotKinematics
- __all__ = ["RobotKinematics"]

View File

@@ -1,4 +1,3 @@
from .config import RobotConfig
from .robot import Robot
- from .robot_wrapper import RobotWrapper
from .utils import make_robot_from_config

View File

@@ -66,7 +66,7 @@ class SO100FollowerEndEffector(SO100Follower):
self.kinematics = RobotKinematics(robot_type="so101")
# Set the forward kinematics function
- self.fk_function = self.kinematics.fk_gripper_tip
+ self.fk_function = self.kinematics.fk_gripper
# Store the bounds for end-effector position
self.end_effector_bounds = self.config.end_effector_bounds
@@ -152,16 +152,16 @@ class SO100FollowerEndEffector(SO100Follower):
fk_func=self.fk_function,
)
target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0)
# Create joint space action dictionary
joint_action = {
f"{key}.pos": target_joint_values_in_degrees[i]
for i, key in enumerate(self.bus.motors.keys())
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
}
# Handle gripper separately if included in action
joint_action["gripper.pos"] = np.clip(
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
0,
- 5,
+ self.config.max_gripper_pos,
)
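
The gripper entry of the end-effector action is treated as a delta around 1 and scaled by max_gripper_pos, and the clip ceiling changes from the hard-coded 5 to self.config.max_gripper_pos. A hedged sketch of that update rule (the 0/1/2 close/hold/open encoding and the helper name are illustrative assumptions):

    import numpy as np

    def gripper_target(current_pos: float, gripper_cmd: float, max_gripper_pos: float) -> float:
        """Map a discrete gripper command to an absolute, clamped gripper position."""
        # gripper_cmd is assumed to be in {0, 1, 2}: close / hold / open.
        delta = (gripper_cmd - 1.0) * max_gripper_pos
        # Clamp into [0, max_gripper_pos] instead of the previous fixed upper bound of 5.
        return float(np.clip(current_pos + delta, 0.0, max_gripper_pos))

    # Example: gripper_target(10.0, 2, 50.0) -> 50.0 (the open command saturates at the limit).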
@@ -191,3 +191,7 @@ class SO100FollowerEndEffector(SO100Follower):
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
+ def reset(self):
+     self.current_ee_pos = None
+     self.current_joint_pos = None

View File

@@ -16,7 +16,6 @@
import logging
import time
from typing import Any
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
@@ -28,7 +27,6 @@ from lerobot.common.motors.feetech import (
from ..teleoperator import Teleoperator
from .config_so101_leader import SO101LeaderConfig
logger = logging.getLogger(__name__)
@@ -141,4 +139,4 @@ class SO101Leader(Teleoperator):
DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect()
logger.info(f"{self} disconnected.")
logger.info(f"{self} disconnected.")

View File

@@ -80,10 +80,13 @@ import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from lerobot.common.cameras import opencv # noqa: F401
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
- from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
from lerobot.common.utils.random_utils import set_seed
+ from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
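
busy_wait is now imported from lerobot.common.utils.robot_utils rather than the old lerobot.common.robot_devices.utils path. Its usual role in the actor loop is to pad each iteration up to the target period; a small sketch under the assumption that busy_wait(seconds) simply blocks for the requested duration:

    import time

    from lerobot.common.utils.robot_utils import busy_wait

    fps = 10  # assumed control frequency
    start = time.perf_counter()
    # ... run one actor step: read observation, query the policy, send the action ...
    elapsed = time.perf_counter() - start
    busy_wait(max(1.0 / fps - elapsed, 0.0))  # keep the loop close to fps Hz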

View File

@@ -25,7 +25,7 @@ import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
- from lerobot.common.cameras import intel, opencv # noqa: F401
+ from lerobot.common.cameras import opencv # noqa: F401
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.model.kinematics import RobotKinematics
@@ -37,9 +37,7 @@ from lerobot.common.robots import ( # noqa: F401
from lerobot.common.teleoperators import (
gamepad, # noqa: F401
make_teleoperator_from_config,
- so101_leader,
)
from lerobot.common.teleoperators.gamepad.configuration_gamepad import GamepadTeleopConfig
from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop
from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.utils import log_say
@@ -307,6 +305,8 @@ class RobotEnv(gym.Env):
"""
super().reset(seed=seed, options=options)
+ self.robot.reset()
# Capture the initial observation.
observation = self._get_observation()
@@ -1003,7 +1003,9 @@ class GripperActionWrapper(gym.ActionWrapper):
gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"]
- gripper_action_value = np.clip(gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos)
+ gripper_action_value = np.clip(
+     gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos
+ )
action[-1] = gripper_action_value.item()
return action
@@ -1138,7 +1140,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
)
# Initialize kinematics instance for the appropriate robot type
- robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
+ robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101")
self.kinematics = RobotKinematics(robot_type)
self.fk_function = self.kinematics.fk_gripper_tip
@@ -1152,16 +1154,10 @@ class EEObservationWrapper(gym.ObservationWrapper):
Returns:
Enhanced observation with end-effector pose information.
"""
- current_joint_pos = self.unwrapped._get_observation()["observation.state"]
+ current_joint_pos = self.unwrapped._get_observation()["agent_pos"]
current_ee_pos = self.fk_function(current_joint_pos)
observation["observation.state"] = torch.cat(
[
observation["observation.state"],
torch.from_numpy(current_ee_pos[:3, 3]),
],
dim=-1,
)
observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos[:3, 3]], -1)
return observation
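
The wrapper now works on the numpy "agent_pos" key instead of the torch "observation.state" tensor, appending the end-effector translation from forward kinematics with np.concatenate. A hedged sketch of that step (assuming, as the [:3, 3] indexing suggests, that the FK function returns a 4x4 homogeneous transform):

    import numpy as np

    def add_ee_pose(observation: dict, joint_pos: np.ndarray, fk_function) -> dict:
        """Append the end-effector XYZ position to the flat state vector."""
        ee_pose = fk_function(joint_pos)   # assumed 4x4 homogeneous transform
        ee_xyz = ee_pose[:3, 3]            # translation component
        observation["agent_pos"] = np.concatenate([observation["agent_pos"], ee_xyz], -1)
        return observation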
@@ -1178,9 +1174,7 @@ class BaseLeaderControlWrapper(gym.Wrapper):
where the human can control a leader robot to guide the follower robot's movements.
"""
- def __init__(
-     self, env, teleop_device, use_geared_leader_arm: bool = False, use_gripper=False
- ):
+ def __init__(self, env, teleop_device, use_geared_leader_arm: bool = False, use_gripper=False):
"""
Initialize the base leader control wrapper.
@@ -1322,10 +1316,12 @@ class BaseLeaderControlWrapper(gym.Wrapper):
if self.use_gripper:
if self.prev_leader_gripper is None:
- self.prev_leader_gripper = np.clip(leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos)
+ self.prev_leader_gripper = np.clip(
+     leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos
+ )
# Get gripper action delta based on leader pose
- leader_gripper = leader_pos[-1]
+ leader_gripper = leader_pos[-1]
# follower_gripper = follower_pos[-1]
gripper_delta = leader_gripper - self.prev_leader_gripper
@@ -1342,7 +1338,7 @@ class BaseLeaderControlWrapper(gym.Wrapper):
action = np.append(action, gripper_action)
# action_intervention = np.append(action_intervention, gripper_delta)
- return action # , action_intervention
+ return action # , action_intervention
def _handle_leader_teleoperation(self):
"""
@@ -1389,7 +1385,11 @@ class BaseLeaderControlWrapper(gym.Wrapper):
info["is_intervention"] = is_intervention
info["action_intervention"] = action if is_intervention else None
- self.prev_leader_gripper = np.clip(self.robot_leader.bus.sync_read("Present_Position")["gripper"], 0, self.robot_follower.config.max_gripper_pos)
+ self.prev_leader_gripper = np.clip(
+     self.robot_leader.bus.sync_read("Present_Position")["gripper"],
+     0,
+     self.robot_follower.config.max_gripper_pos,
+ )
# Check for success or manual termination
success = self.keyboard_events["episode_success"]
@@ -1569,7 +1569,7 @@ class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper):
self.ee_error_over_time_queue[-1] - self.previous_ee_error_over_time_over_time
)
self.previous_ee_error_over_time_over_time = self.ee_error_over_time_queue[-1]
self.intervention_threshold = 0.02
# Determine if intervention should start or stop based on the thresholds set in the constructor
@@ -1890,8 +1890,8 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.wrapper.add_current_to_observation:
env = AddCurrentToObservation(env=env)
- if False and cfg.wrapper.add_ee_pose_to_observation:
-     env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
+ if cfg.wrapper.add_ee_pose_to_observation:
+     env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds)
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
@@ -1910,9 +1910,9 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
if cfg.wrapper:
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper:
- env = GripperActionWrapper(
-     env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
- )
+ # env = GripperActionWrapper(
+ #     env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
+ # )
if cfg.wrapper.gripper_penalty is not None:
env = GripperPenaltyWrapper(
env=env,
@@ -1922,7 +1922,9 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
# Control mode specific wrappers
control_mode = cfg.wrapper.control_mode
if control_mode == "gamepad":
- assert isinstance(teleop_device, GamepadTeleop), "teleop_device must be an instance of GamepadTeleop for gamepad control mode"
+ assert isinstance(teleop_device, GamepadTeleop), (
+     "teleop_device must be an instance of GamepadTeleop for gamepad control mode"
+ )
env = GamepadControlWrapper(
env=env,
teleop_device=teleop_device,
@@ -2117,11 +2119,10 @@ def record_dataset(env, policy, cfg):
really_done = success_steps_collected >= cfg.number_of_steps_after_success
frame["next.done"] = np.array([really_done], dtype=bool)
frame["task"] = cfg.task
frame["complementary_info.discrete_penalty"] = torch.tensor(
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
)
- dataset.add_frame(frame)
+ dataset.add_frame(frame, task=cfg.task)
# Maintain consistent timing
if cfg.fps:
@@ -2233,7 +2234,7 @@ def main(cfg: EnvConfig):
while num_episode < 10:
start_loop_s = time.perf_counter()
# Sample a new random action from the robot's action space.
- new_random_action = env.action_space.sample()
+ new_random_action = env.action_space.sample() * 0.0
# Update the smoothed action using an exponential moving average.
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
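
Scaling the sampled action by 0.0 turns this test loop into a pure decay: the exponential moving average now only shrinks the previous command toward zero each step. A small self-contained sketch of the smoothing rule (alpha and the action dimension are illustrative values):

    import numpy as np

    alpha = 0.2                      # assumed smoothing factor
    smoothed_action = np.ones(4)     # assumed 4-D action, starting away from zero
    for _ in range(5):
        new_random_action = np.zeros(4)  # mirrors env.action_space.sample() * 0.0
        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
    # After k steps the command is (1 - alpha) ** k of its initial value, a geometric decay to zero.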

View File

@@ -91,6 +91,7 @@ from torch import nn
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
+ from lerobot.common.cameras import so100_follower_end_effector # noqa: F401
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
@@ -101,6 +102,8 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
+ from lerobot.common.robots import so100_follower_end_effector # noqa: F401
+ from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,