General fixes to abide by the new config in learner_server, actor_server, gym_manipulator
This commit is contained in:
committed by
AdilZouitine
parent
df96e5b3b2
commit
1edfbf792a
@@ -57,7 +57,7 @@ class AlohaEnv(EnvConfig):
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"agent_pos": OBS_STATE,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
}
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from lerobot.common.model.kinematics_utils import RobotKinematics
|
||||
|
||||
__all__ = ["RobotKinematics"]
|
||||
@@ -1,4 +1,3 @@
|
||||
from .config import RobotConfig
|
||||
from .robot import Robot
|
||||
from .robot_wrapper import RobotWrapper
|
||||
from .utils import make_robot_from_config
|
||||
|
||||
@@ -66,7 +66,7 @@ class SO100FollowerEndEffector(SO100Follower):
|
||||
self.kinematics = RobotKinematics(robot_type="so101")
|
||||
|
||||
# Set the forward kinematics function
|
||||
self.fk_function = self.kinematics.fk_gripper_tip
|
||||
self.fk_function = self.kinematics.fk_gripper
|
||||
|
||||
# Store the bounds for end-effector position
|
||||
self.end_effector_bounds = self.config.end_effector_bounds
|
||||
@@ -152,16 +152,16 @@ class SO100FollowerEndEffector(SO100Follower):
|
||||
fk_func=self.fk_function,
|
||||
)
|
||||
|
||||
target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0)
|
||||
# Create joint space action dictionary
|
||||
joint_action = {
|
||||
f"{key}.pos": target_joint_values_in_degrees[i]
|
||||
for i, key in enumerate(self.bus.motors.keys())
|
||||
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
|
||||
}
|
||||
|
||||
# Handle gripper separately if included in action
|
||||
joint_action["gripper.pos"] = np.clip(
|
||||
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
|
||||
0,
|
||||
5,
|
||||
self.config.max_gripper_pos,
|
||||
)
|
||||
|
||||
@@ -191,3 +191,7 @@ class SO100FollowerEndEffector(SO100Follower):
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def reset(self):
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||
@@ -28,7 +27,6 @@ from lerobot.common.motors.feetech import (
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_so101_leader import SO101LeaderConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -141,4 +139,4 @@ class SO101Leader(Teleoperator):
|
||||
DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -80,10 +80,13 @@ import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.cameras import opencv # noqa: F401
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
||||
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.robot_utils import busy_wait
|
||||
from lerobot.common.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
|
||||
@@ -25,7 +25,7 @@ import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.cameras import intel, opencv # noqa: F401
|
||||
from lerobot.common.cameras import opencv # noqa: F401
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.model.kinematics import RobotKinematics
|
||||
@@ -37,9 +37,7 @@ from lerobot.common.robots import ( # noqa: F401
|
||||
from lerobot.common.teleoperators import (
|
||||
gamepad, # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.common.teleoperators.gamepad.configuration_gamepad import GamepadTeleopConfig
|
||||
from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop
|
||||
from lerobot.common.utils.robot_utils import busy_wait
|
||||
from lerobot.common.utils.utils import log_say
|
||||
@@ -307,6 +305,8 @@ class RobotEnv(gym.Env):
|
||||
"""
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
self.robot.reset()
|
||||
|
||||
# Capture the initial observation.
|
||||
observation = self._get_observation()
|
||||
|
||||
@@ -1003,7 +1003,9 @@ class GripperActionWrapper(gym.ActionWrapper):
|
||||
|
||||
gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"]
|
||||
|
||||
gripper_action_value = np.clip(gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos)
|
||||
gripper_action_value = np.clip(
|
||||
gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos
|
||||
)
|
||||
action[-1] = gripper_action_value.item()
|
||||
return action
|
||||
|
||||
@@ -1138,7 +1140,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
|
||||
)
|
||||
|
||||
# Initialize kinematics instance for the appropriate robot type
|
||||
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
|
||||
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101")
|
||||
self.kinematics = RobotKinematics(robot_type)
|
||||
self.fk_function = self.kinematics.fk_gripper_tip
|
||||
|
||||
@@ -1152,16 +1154,10 @@ class EEObservationWrapper(gym.ObservationWrapper):
|
||||
Returns:
|
||||
Enhanced observation with end-effector pose information.
|
||||
"""
|
||||
current_joint_pos = self.unwrapped._get_observation()["observation.state"]
|
||||
current_joint_pos = self.unwrapped._get_observation()["agent_pos"]
|
||||
|
||||
current_ee_pos = self.fk_function(current_joint_pos)
|
||||
observation["observation.state"] = torch.cat(
|
||||
[
|
||||
observation["observation.state"],
|
||||
torch.from_numpy(current_ee_pos[:3, 3]),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos[:3, 3]], -1)
|
||||
return observation
|
||||
|
||||
|
||||
@@ -1178,9 +1174,7 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
||||
where the human can control a leader robot to guide the follower robot's movements.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, env, teleop_device, use_geared_leader_arm: bool = False, use_gripper=False
|
||||
):
|
||||
def __init__(self, env, teleop_device, use_geared_leader_arm: bool = False, use_gripper=False):
|
||||
"""
|
||||
Initialize the base leader control wrapper.
|
||||
|
||||
@@ -1322,10 +1316,12 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
||||
|
||||
if self.use_gripper:
|
||||
if self.prev_leader_gripper is None:
|
||||
self.prev_leader_gripper = np.clip(leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos)
|
||||
|
||||
self.prev_leader_gripper = np.clip(
|
||||
leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos
|
||||
)
|
||||
|
||||
# Get gripper action delta based on leader pose
|
||||
leader_gripper = leader_pos[-1]
|
||||
leader_gripper = leader_pos[-1]
|
||||
# follower_gripper = follower_pos[-1]
|
||||
gripper_delta = leader_gripper - self.prev_leader_gripper
|
||||
|
||||
@@ -1342,7 +1338,7 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
||||
action = np.append(action, gripper_action)
|
||||
# action_intervention = np.append(action_intervention, gripper_delta)
|
||||
|
||||
return action # , action_intervention
|
||||
return action # , action_intervention
|
||||
|
||||
def _handle_leader_teleoperation(self):
|
||||
"""
|
||||
@@ -1389,7 +1385,11 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
||||
info["is_intervention"] = is_intervention
|
||||
info["action_intervention"] = action if is_intervention else None
|
||||
|
||||
self.prev_leader_gripper = np.clip(self.robot_leader.bus.sync_read("Present_Position")["gripper"], 0, self.robot_follower.config.max_gripper_pos)
|
||||
self.prev_leader_gripper = np.clip(
|
||||
self.robot_leader.bus.sync_read("Present_Position")["gripper"],
|
||||
0,
|
||||
self.robot_follower.config.max_gripper_pos,
|
||||
)
|
||||
|
||||
# Check for success or manual termination
|
||||
success = self.keyboard_events["episode_success"]
|
||||
@@ -1569,7 +1569,7 @@ class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper):
|
||||
self.ee_error_over_time_queue[-1] - self.previous_ee_error_over_time_over_time
|
||||
)
|
||||
self.previous_ee_error_over_time_over_time = self.ee_error_over_time_queue[-1]
|
||||
|
||||
|
||||
self.intervention_threshold = 0.02
|
||||
|
||||
# Determine if intervention should start or stop based on the thresholds set in the constructor
|
||||
@@ -1890,8 +1890,8 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
||||
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
|
||||
if cfg.wrapper.add_current_to_observation:
|
||||
env = AddCurrentToObservation(env=env)
|
||||
if False and cfg.wrapper.add_ee_pose_to_observation:
|
||||
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
|
||||
if cfg.wrapper.add_ee_pose_to_observation:
|
||||
env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds)
|
||||
|
||||
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
||||
|
||||
@@ -1910,9 +1910,9 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
||||
if cfg.wrapper:
|
||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||
if cfg.wrapper.use_gripper:
|
||||
env = GripperActionWrapper(
|
||||
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
||||
)
|
||||
# env = GripperActionWrapper(
|
||||
# env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
||||
# )
|
||||
if cfg.wrapper.gripper_penalty is not None:
|
||||
env = GripperPenaltyWrapper(
|
||||
env=env,
|
||||
@@ -1922,7 +1922,9 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
||||
# Control mode specific wrappers
|
||||
control_mode = cfg.wrapper.control_mode
|
||||
if control_mode == "gamepad":
|
||||
assert isinstance(teleop_device, GamepadTeleop), "teleop_device must be an instance of GamepadTeleop for gamepad control mode"
|
||||
assert isinstance(teleop_device, GamepadTeleop), (
|
||||
"teleop_device must be an instance of GamepadTeleop for gamepad control mode"
|
||||
)
|
||||
env = GamepadControlWrapper(
|
||||
env=env,
|
||||
teleop_device=teleop_device,
|
||||
@@ -2117,11 +2119,10 @@ def record_dataset(env, policy, cfg):
|
||||
really_done = success_steps_collected >= cfg.number_of_steps_after_success
|
||||
|
||||
frame["next.done"] = np.array([really_done], dtype=bool)
|
||||
frame["task"] = cfg.task
|
||||
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
||||
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
||||
)
|
||||
dataset.add_frame(frame)
|
||||
dataset.add_frame(frame, task=cfg.task)
|
||||
|
||||
# Maintain consistent timing
|
||||
if cfg.fps:
|
||||
@@ -2233,7 +2234,7 @@ def main(cfg: EnvConfig):
|
||||
while num_episode < 10:
|
||||
start_loop_s = time.perf_counter()
|
||||
# Sample a new random action from the robot's action space.
|
||||
new_random_action = env.action_space.sample()
|
||||
new_random_action = env.action_space.sample() * 0.0
|
||||
# Update the smoothed action using an exponential moving average.
|
||||
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ from torch import nn
|
||||
from torch.multiprocessing import Queue
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from lerobot.common.cameras import so100_follower_end_effector # noqa: F401
|
||||
from lerobot.common.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
@@ -101,6 +102,8 @@ from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
||||
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
|
||||
Reference in New Issue
Block a user