Fix naming convention in gym_manipulator and adapt get_observation to so100_follower_end_effector
This commit is contained in:
committed by
AdilZouitine
parent
2475645f5f
commit
2f62e5496e
@@ -18,6 +18,7 @@ import logging
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import time
|
||||||
|
|
||||||
from lerobot.common.errors import DeviceNotConnectedError
|
from lerobot.common.errors import DeviceNotConnectedError
|
||||||
from lerobot.common.model.kinematics import RobotKinematics
|
from lerobot.common.model.kinematics import RobotKinematics
|
||||||
@@ -26,6 +27,7 @@ from lerobot.common.motors.feetech import FeetechMotorsBus
|
|||||||
|
|
||||||
from ..so100_follower import SO100Follower
|
from ..so100_follower import SO100Follower
|
||||||
from .config_so100_follower_end_effector import SO100FollowerEndEffectorConfig
|
from .config_so100_follower_end_effector import SO100FollowerEndEffectorConfig
|
||||||
|
from lerobot.common.cameras import make_cameras_from_configs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -56,6 +58,8 @@ class SO100FollowerEndEffector(SO100Follower):
|
|||||||
calibration=self.calibration,
|
calibration=self.calibration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Initialize the kinematics module for the so100 robot
|
# Initialize the kinematics module for the so100 robot
|
||||||
@@ -164,3 +168,24 @@ class SO100FollowerEndEffector(SO100Follower):
|
|||||||
)
|
)
|
||||||
# Send joint space action to parent class
|
# Send joint space action to parent class
|
||||||
return super().send_action(joint_action)
|
return super().send_action(joint_action)
|
||||||
|
|
||||||
|
|
||||||
|
def get_observation(self) -> dict[str, Any]:
    """Read the current motor positions and one frame from every camera.

    Returns:
        A dict with one ``"<motor>.pos"`` entry per motor returned by
        ``self.bus.sync_read("Present_Position")``, plus one entry per key
        in ``self.cameras`` holding the result of that camera's
        ``async_read()``.

    Raises:
        DeviceNotConnectedError: If ``self.is_connected`` is false.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"{self} is not connected.")

    # Read arm position
    start = time.perf_counter()
    positions = self.bus.sync_read("Present_Position")
    obs_dict = {f"{motor}.pos": val for motor, val in positions.items()}
    dt_ms = (time.perf_counter() - start) * 1e3
    # Lazy %-style arguments: the message (and str(self)) is only built when
    # DEBUG logging is enabled, which matters because this runs every step.
    logger.debug("%s read state: %.1fms", self, dt_ms)

    # Capture images from cameras
    for cam_key, cam in self.cameras.items():
        start = time.perf_counter()
        obs_dict[cam_key] = cam.async_read()
        dt_ms = (time.perf_counter() - start) * 1e3
        logger.debug("%s read %s: %.1fms", self, cam_key, dt_ms)

    return obs_dict
|
||||||
|
|||||||
@@ -220,15 +220,22 @@ class RobotEnv(gym.Env):
|
|||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.episode_data = None
|
self.episode_data = None
|
||||||
|
|
||||||
|
self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors.keys()]
|
||||||
|
self._image_keys = self.robot.cameras.keys()
|
||||||
|
|
||||||
# Read initial joint positions using the bus
|
# Read initial joint positions using the bus
|
||||||
self.current_joint_positions = self._get_observation()
|
self.current_joint_positions = self._get_observation()["agent_pos"]
|
||||||
|
|
||||||
|
|
||||||
self._setup_spaces()
|
self._setup_spaces()
|
||||||
|
|
||||||
def _get_observation(self) -> dict:
    """Query the robot and split its observation into joint state and images.

    Returns:
        A dict with two entries:

        - ``"agent_pos"``: float32 array holding the robot observation values
          for ``self._joint_names``, in that order.
        - ``"pixels"``: dict mapping each key in ``self._image_keys`` to the
          matching entry of the robot observation.
    """
    obs_dict = self.robot.get_observation()

    # Index by the stored joint names so the array order is fixed by
    # self._joint_names rather than by the iteration order of the robot's dict.
    joint_positions = np.array([obs_dict[name] for name in self._joint_names], dtype=np.float32)
    images = {key: obs_dict[key] for key in self._image_keys}

    # NOTE(review): the result has no "observation.state" key; callers must
    # read the joint vector from "agent_pos".
    return {"agent_pos": joint_positions, "pixels": images}
|
||||||
|
|
||||||
def _setup_spaces(self):
|
def _setup_spaces(self):
|
||||||
"""
|
"""
|
||||||
@@ -244,16 +251,20 @@ class RobotEnv(gym.Env):
|
|||||||
"""
|
"""
|
||||||
example_obs = self._get_observation()
|
example_obs = self._get_observation()
|
||||||
|
|
||||||
|
observation_spaces = {}
|
||||||
|
|
||||||
# Define observation spaces for images and other states.
|
# Define observation spaces for images and other states.
|
||||||
image_keys = [key for key in example_obs if "image" in key]
|
if "pixels" in example_obs:
|
||||||
observation_spaces = {
|
prefix = "observation.images" if len(example_obs["pixels"]) > 1 else "observation.image"
|
||||||
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
|
observation_spaces = {
|
||||||
for key in image_keys
|
f"{prefix}.{key}": gym.spaces.Box(low=0, high=255, shape=example_obs["pixels"][key].shape, dtype=np.uint8)
|
||||||
}
|
for key in example_obs["pixels"]
|
||||||
|
}
|
||||||
|
|
||||||
observation_spaces["observation.state"] = gym.spaces.Box(
|
observation_spaces["observation.state"] = gym.spaces.Box(
|
||||||
low=0,
|
low=0,
|
||||||
high=10,
|
high=10,
|
||||||
shape=example_obs["observation.state"].shape,
|
shape=example_obs["agent_pos"].shape,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -315,10 +326,10 @@ class RobotEnv(gym.Env):
|
|||||||
- truncated (bool): True if the episode was truncated (e.g., time constraints).
|
- truncated (bool): True if the episode was truncated (e.g., time constraints).
|
||||||
- info (dict): Additional debugging information including intervention status.
|
- info (dict): Additional debugging information including intervention status.
|
||||||
"""
|
"""
|
||||||
self.current_joint_positions = self._get_observation()
|
self.current_joint_positions = self._get_observation()["observation.state"]
|
||||||
|
|
||||||
self.robot.send_action(torch.from_numpy(action))
|
self.robot.send_action(torch.from_numpy(action))
|
||||||
observation = self._get_observation()
|
observation = self._get_observation()["observation.state"]
|
||||||
|
|
||||||
if self.display_cameras:
|
if self.display_cameras:
|
||||||
self.render()
|
self.render()
|
||||||
@@ -412,11 +423,9 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
The modified observation with joint velocities.
|
The modified observation with joint velocities.
|
||||||
"""
|
"""
|
||||||
joint_velocities = (observation["observation.state"] - self.last_joint_positions) / self.dt
|
joint_velocities = (observation["agent_pos"] - self.last_joint_positions) / self.dt
|
||||||
self.last_joint_positions = observation["observation.state"].clone()
|
self.last_joint_positions = observation["agent_pos"]
|
||||||
observation["observation.state"] = torch.cat(
|
observation["agent_pos"] = np.concatenate([observation["agent_pos"], joint_velocities], axis=-1)
|
||||||
[observation["observation.state"], joint_velocities], dim=-1
|
|
||||||
)
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
@@ -466,12 +475,8 @@ class AddCurrentToObservation(gym.ObservationWrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
The modified observation with current values.
|
The modified observation with current values.
|
||||||
"""
|
"""
|
||||||
present_current_dict = self.unwrapped.robot.bus.sync_read("Present_Current")
|
present_current_observation = self.unwrapped._get_observation()["agent_pos"]
|
||||||
present_current_observation = torch.tensor([present_current_dict[name] for name in present_current_dict.keys()], dtype=np.float32)
|
observation["agent_pos"] = np.concatenate([observation["agent_pos"], present_current_observation], axis=-1)
|
||||||
|
|
||||||
observation["observation.state"] = torch.cat(
|
|
||||||
[observation["observation.state"], present_current_observation], dim=-1
|
|
||||||
)
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
@@ -740,16 +745,8 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
The processed observation with normalized images and proper tensor formats.
|
The processed observation with normalized images and proper tensor formats.
|
||||||
"""
|
"""
|
||||||
for key in observation:
|
observation = preprocess_observation(observation)
|
||||||
observation[key] = observation[key].float()
|
observation = {key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") for key in observation}
|
||||||
if "image" in key:
|
|
||||||
observation[key] = observation[key].permute(2, 0, 1)
|
|
||||||
observation[key] /= 255.0
|
|
||||||
observation = {
|
|
||||||
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
|
||||||
for key in observation
|
|
||||||
}
|
|
||||||
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
@@ -1078,7 +1075,7 @@ class EEActionWrapper(gym.ActionWrapper):
|
|||||||
gripper_command = action[-1]
|
gripper_command = action[-1]
|
||||||
action = action[:-1]
|
action = action[:-1]
|
||||||
|
|
||||||
current_joint_pos = self.unwrapped._get_observation()
|
current_joint_pos = self.unwrapped._get_observation()["observation.state"]
|
||||||
|
|
||||||
current_ee_pos = self.fk_function(current_joint_pos)
|
current_ee_pos = self.fk_function(current_joint_pos)
|
||||||
desired_ee_pos[:3, 3] = np.clip(
|
desired_ee_pos[:3, 3] = np.clip(
|
||||||
@@ -1141,7 +1138,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
Enhanced observation with end-effector pose information.
|
Enhanced observation with end-effector pose information.
|
||||||
"""
|
"""
|
||||||
current_joint_pos = self.unwrapped._get_observation()
|
current_joint_pos = self.unwrapped._get_observation()["observation.state"]
|
||||||
|
|
||||||
current_ee_pos = self.fk_function(current_joint_pos)
|
current_ee_pos = self.fk_function(current_joint_pos)
|
||||||
observation["observation.state"] = torch.cat(
|
observation["observation.state"] = torch.cat(
|
||||||
@@ -1881,9 +1878,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
|||||||
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
|
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
|
||||||
if cfg.wrapper.add_current_to_observation:
|
if cfg.wrapper.add_current_to_observation:
|
||||||
env = AddCurrentToObservation(env=env)
|
env = AddCurrentToObservation(env=env)
|
||||||
if cfg.wrapper.add_ee_pose_to_observation:
|
if False and cfg.wrapper.add_ee_pose_to_observation:
|
||||||
if cfg.wrapper.ee_action_space_params is None or cfg.wrapper.ee_action_space_params.bounds is None:
|
|
||||||
raise ValueError("EEActionSpaceConfig with bounds must be provided for EEObservationWrapper.")
|
|
||||||
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
|
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
|
||||||
|
|
||||||
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
||||||
@@ -1917,7 +1912,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
# Control mode specific wrappers
|
# Control mode specific wrappers
|
||||||
control_mode = cfg.wrapper.ee_action_space_params.control_mode
|
control_mode = cfg.wrapper.control_mode
|
||||||
if control_mode == "gamepad":
|
if control_mode == "gamepad":
|
||||||
if teleop_device is None:
|
if teleop_device is None:
|
||||||
raise ValueError("A teleop_device must be instantiated for gamepad control mode.")
|
raise ValueError("A teleop_device must be instantiated for gamepad control mode.")
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ def teleop_loop(
|
|||||||
if isinstance(val, float):
|
if isinstance(val, float):
|
||||||
rr.log(f"action_{act}", rr.Scalar(val))
|
rr.log(f"action_{act}", rr.Scalar(val))
|
||||||
|
|
||||||
|
breakpoint()
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
loop_s = time.perf_counter() - loop_start
|
loop_s = time.perf_counter() - loop_start
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user