Fixed naming convention in gym_manipulator; adapted get_observation to so100_follower_end_effector

This commit is contained in:
Michel Aractingi
2025-05-23 17:58:40 +02:00
committed by AdilZouitine
parent 2475645f5f
commit 2f62e5496e
3 changed files with 59 additions and 38 deletions

View File

@@ -18,6 +18,7 @@ import logging
from typing import Any, Dict from typing import Any, Dict
import numpy as np import numpy as np
import time
from lerobot.common.errors import DeviceNotConnectedError from lerobot.common.errors import DeviceNotConnectedError
from lerobot.common.model.kinematics import RobotKinematics from lerobot.common.model.kinematics import RobotKinematics
@@ -26,6 +27,7 @@ from lerobot.common.motors.feetech import FeetechMotorsBus
from ..so100_follower import SO100Follower from ..so100_follower import SO100Follower
from .config_so100_follower_end_effector import SO100FollowerEndEffectorConfig from .config_so100_follower_end_effector import SO100FollowerEndEffectorConfig
from lerobot.common.cameras import make_cameras_from_configs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -56,6 +58,8 @@ class SO100FollowerEndEffector(SO100Follower):
calibration=self.calibration, calibration=self.calibration,
) )
self.cameras = make_cameras_from_configs(config.cameras)
self.config = config self.config = config
# Initialize the kinematics module for the so100 robot # Initialize the kinematics module for the so100 robot
@@ -164,3 +168,24 @@ class SO100FollowerEndEffector(SO100Follower):
) )
# Send joint space action to parent class # Send joint space action to parent class
return super().send_action(joint_action) return super().send_action(joint_action)
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict

View File

@@ -220,15 +220,22 @@ class RobotEnv(gym.Env):
self.current_step = 0 self.current_step = 0
self.episode_data = None self.episode_data = None
self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors.keys()]
self._image_keys = self.robot.cameras.keys()
# Read initial joint positions using the bus # Read initial joint positions using the bus
self.current_joint_positions = self._get_observation() self.current_joint_positions = self._get_observation()["agent_pos"]
self._setup_spaces() self._setup_spaces()
def _get_observation(self) -> np.ndarray: def _get_observation(self) -> np.ndarray:
"""Helper to convert a dictionary from bus.sync_read to an ordered numpy array.""" """Helper to convert a dictionary from bus.sync_read to an ordered numpy array."""
joint_positions_dict = self.robot.bus.sync_read("Present_Position") obs_dict = self.robot.get_observation()
return np.array([joint_positions_dict[name] for name in joint_positions_dict.keys()], dtype=np.float32) joint_positions = np.array([obs_dict[name] for name in self._joint_names], dtype=np.float32)
images = {key: obs_dict[key] for key in self._image_keys}
return {"agent_pos": joint_positions, "pixels": images}
def _setup_spaces(self): def _setup_spaces(self):
""" """
@@ -244,16 +251,20 @@ class RobotEnv(gym.Env):
""" """
example_obs = self._get_observation() example_obs = self._get_observation()
observation_spaces = {}
# Define observation spaces for images and other states. # Define observation spaces for images and other states.
image_keys = [key for key in example_obs if "image" in key] if "pixels" in example_obs:
observation_spaces = { prefix = "observation.images" if len(example_obs["pixels"]) > 1 else "observation.image"
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8) observation_spaces = {
for key in image_keys f"{prefix}.{key}": gym.spaces.Box(low=0, high=255, shape=example_obs["pixels"][key].shape, dtype=np.uint8)
} for key in example_obs["pixels"]
}
observation_spaces["observation.state"] = gym.spaces.Box( observation_spaces["observation.state"] = gym.spaces.Box(
low=0, low=0,
high=10, high=10,
shape=example_obs["observation.state"].shape, shape=example_obs["agent_pos"].shape,
dtype=np.float32, dtype=np.float32,
) )
@@ -315,10 +326,10 @@ class RobotEnv(gym.Env):
- truncated (bool): True if the episode was truncated (e.g., time constraints). - truncated (bool): True if the episode was truncated (e.g., time constraints).
- info (dict): Additional debugging information including intervention status. - info (dict): Additional debugging information including intervention status.
""" """
self.current_joint_positions = self._get_observation() self.current_joint_positions = self._get_observation()["observation.state"]
self.robot.send_action(torch.from_numpy(action)) self.robot.send_action(torch.from_numpy(action))
observation = self._get_observation() observation = self._get_observation()["observation.state"]
if self.display_cameras: if self.display_cameras:
self.render() self.render()
@@ -412,11 +423,9 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
Returns: Returns:
The modified observation with joint velocities. The modified observation with joint velocities.
""" """
joint_velocities = (observation["observation.state"] - self.last_joint_positions) / self.dt joint_velocities = (observation["agent_pos"] - self.last_joint_positions) / self.dt
self.last_joint_positions = observation["observation.state"].clone() self.last_joint_positions = observation["agent_pos"]
observation["observation.state"] = torch.cat( observation["agent_pos"] = np.concatenate([observation["agent_pos"], joint_velocities], axis=-1)
[observation["observation.state"], joint_velocities], dim=-1
)
return observation return observation
@@ -466,12 +475,8 @@ class AddCurrentToObservation(gym.ObservationWrapper):
Returns: Returns:
The modified observation with current values. The modified observation with current values.
""" """
present_current_dict = self.unwrapped.robot.bus.sync_read("Present_Current") present_current_observation = self.unwrapped._get_observation()["agent_pos"]
present_current_observation = torch.tensor([present_current_dict[name] for name in present_current_dict.keys()], dtype=np.float32) observation["agent_pos"] = np.concatenate([observation["agent_pos"], present_current_observation], axis=-1)
observation["observation.state"] = torch.cat(
[observation["observation.state"], present_current_observation], dim=-1
)
return observation return observation
@@ -740,16 +745,8 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
Returns: Returns:
The processed observation with normalized images and proper tensor formats. The processed observation with normalized images and proper tensor formats.
""" """
for key in observation: observation = preprocess_observation(observation)
observation[key] = observation[key].float() observation = {key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") for key in observation}
if "image" in key:
observation[key] = observation[key].permute(2, 0, 1)
observation[key] /= 255.0
observation = {
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
}
return observation return observation
@@ -1078,7 +1075,7 @@ class EEActionWrapper(gym.ActionWrapper):
gripper_command = action[-1] gripper_command = action[-1]
action = action[:-1] action = action[:-1]
current_joint_pos = self.unwrapped._get_observation() current_joint_pos = self.unwrapped._get_observation()["observation.state"]
current_ee_pos = self.fk_function(current_joint_pos) current_ee_pos = self.fk_function(current_joint_pos)
desired_ee_pos[:3, 3] = np.clip( desired_ee_pos[:3, 3] = np.clip(
@@ -1141,7 +1138,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
Returns: Returns:
Enhanced observation with end-effector pose information. Enhanced observation with end-effector pose information.
""" """
current_joint_pos = self.unwrapped._get_observation() current_joint_pos = self.unwrapped._get_observation()["observation.state"]
current_ee_pos = self.fk_function(current_joint_pos) current_ee_pos = self.fk_function(current_joint_pos)
observation["observation.state"] = torch.cat( observation["observation.state"] = torch.cat(
@@ -1881,9 +1878,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps) env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.wrapper.add_current_to_observation: if cfg.wrapper.add_current_to_observation:
env = AddCurrentToObservation(env=env) env = AddCurrentToObservation(env=env)
if cfg.wrapper.add_ee_pose_to_observation: if False and cfg.wrapper.add_ee_pose_to_observation:
if cfg.wrapper.ee_action_space_params is None or cfg.wrapper.ee_action_space_params.bounds is None:
raise ValueError("EEActionSpaceConfig with bounds must be provided for EEObservationWrapper.")
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds) env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
env = ConvertToLeRobotObservation(env=env, device=cfg.device) env = ConvertToLeRobotObservation(env=env, device=cfg.device)
@@ -1917,7 +1912,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
# ) # )
# Control mode specific wrappers # Control mode specific wrappers
control_mode = cfg.wrapper.ee_action_space_params.control_mode control_mode = cfg.wrapper.control_mode
if control_mode == "gamepad": if control_mode == "gamepad":
if teleop_device is None: if teleop_device is None:
raise ValueError("A teleop_device must be instantiated for gamepad control mode.") raise ValueError("A teleop_device must be instantiated for gamepad control mode.")

View File

@@ -91,6 +91,7 @@ def teleop_loop(
if isinstance(val, float): if isinstance(val, float):
rr.log(f"action_{act}", rr.Scalar(val)) rr.log(f"action_{act}", rr.Scalar(val))
breakpoint()
robot.send_action(action) robot.send_action(action)
loop_s = time.perf_counter() - loop_start loop_s = time.perf_counter() - loop_start