forked from tangger/lerobot
General fixes to abide by the new config in learner_server, actor_server, gym_manipulator
This commit is contained in:
committed by
AdilZouitine
parent
df96e5b3b2
commit
1edfbf792a
@@ -57,7 +57,7 @@ class AlohaEnv(EnvConfig):
|
|||||||
features_map: dict[str, str] = field(
|
features_map: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": ACTION,
|
"action": ACTION,
|
||||||
"agent_pos": OBS_ROBOT,
|
"agent_pos": OBS_STATE,
|
||||||
"top": f"{OBS_IMAGE}.top",
|
"top": f"{OBS_IMAGE}.top",
|
||||||
"pixels/top": f"{OBS_IMAGES}.top",
|
"pixels/top": f"{OBS_IMAGES}.top",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
from lerobot.common.model.kinematics_utils import RobotKinematics
|
|
||||||
|
|
||||||
__all__ = ["RobotKinematics"]
|
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
from .config import RobotConfig
|
from .config import RobotConfig
|
||||||
from .robot import Robot
|
from .robot import Robot
|
||||||
from .robot_wrapper import RobotWrapper
|
|
||||||
from .utils import make_robot_from_config
|
from .utils import make_robot_from_config
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class SO100FollowerEndEffector(SO100Follower):
|
|||||||
self.kinematics = RobotKinematics(robot_type="so101")
|
self.kinematics = RobotKinematics(robot_type="so101")
|
||||||
|
|
||||||
# Set the forward kinematics function
|
# Set the forward kinematics function
|
||||||
self.fk_function = self.kinematics.fk_gripper_tip
|
self.fk_function = self.kinematics.fk_gripper
|
||||||
|
|
||||||
# Store the bounds for end-effector position
|
# Store the bounds for end-effector position
|
||||||
self.end_effector_bounds = self.config.end_effector_bounds
|
self.end_effector_bounds = self.config.end_effector_bounds
|
||||||
@@ -152,16 +152,16 @@ class SO100FollowerEndEffector(SO100Follower):
|
|||||||
fk_func=self.fk_function,
|
fk_func=self.fk_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0)
|
||||||
# Create joint space action dictionary
|
# Create joint space action dictionary
|
||||||
joint_action = {
|
joint_action = {
|
||||||
f"{key}.pos": target_joint_values_in_degrees[i]
|
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
|
||||||
for i, key in enumerate(self.bus.motors.keys())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle gripper separately if included in action
|
# Handle gripper separately if included in action
|
||||||
joint_action["gripper.pos"] = np.clip(
|
joint_action["gripper.pos"] = np.clip(
|
||||||
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
|
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
|
||||||
0,
|
5,
|
||||||
self.config.max_gripper_pos,
|
self.config.max_gripper_pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -191,3 +191,7 @@ class SO100FollowerEndEffector(SO100Follower):
|
|||||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.current_ee_pos = None
|
||||||
|
self.current_joint_pos = None
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
@@ -28,7 +27,6 @@ from lerobot.common.motors.feetech import (
|
|||||||
from ..teleoperator import Teleoperator
|
from ..teleoperator import Teleoperator
|
||||||
from .config_so101_leader import SO101LeaderConfig
|
from .config_so101_leader import SO101LeaderConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -141,4 +139,4 @@ class SO101Leader(Teleoperator):
|
|||||||
DeviceNotConnectedError(f"{self} is not connected.")
|
DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
self.bus.disconnect()
|
self.bus.disconnect()
|
||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|||||||
@@ -80,10 +80,13 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.multiprocessing import Event, Queue
|
from torch.multiprocessing import Event, Queue
|
||||||
|
|
||||||
|
from lerobot.common.cameras import opencv # noqa: F401
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.common.robot_devices.utils import busy_wait
|
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
||||||
|
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
||||||
from lerobot.common.utils.random_utils import set_seed
|
from lerobot.common.utils.random_utils import set_seed
|
||||||
|
from lerobot.common.utils.robot_utils import busy_wait
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
TimerManager,
|
TimerManager,
|
||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as F # noqa: N812
|
import torchvision.transforms.functional as F # noqa: N812
|
||||||
|
|
||||||
from lerobot.common.cameras import intel, opencv # noqa: F401
|
from lerobot.common.cameras import opencv # noqa: F401
|
||||||
from lerobot.common.envs.configs import EnvConfig
|
from lerobot.common.envs.configs import EnvConfig
|
||||||
from lerobot.common.envs.utils import preprocess_observation
|
from lerobot.common.envs.utils import preprocess_observation
|
||||||
from lerobot.common.model.kinematics import RobotKinematics
|
from lerobot.common.model.kinematics import RobotKinematics
|
||||||
@@ -37,9 +37,7 @@ from lerobot.common.robots import ( # noqa: F401
|
|||||||
from lerobot.common.teleoperators import (
|
from lerobot.common.teleoperators import (
|
||||||
gamepad, # noqa: F401
|
gamepad, # noqa: F401
|
||||||
make_teleoperator_from_config,
|
make_teleoperator_from_config,
|
||||||
so101_leader,
|
|
||||||
)
|
)
|
||||||
from lerobot.common.teleoperators.gamepad.configuration_gamepad import GamepadTeleopConfig
|
|
||||||
from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop
|
from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop
|
||||||
from lerobot.common.utils.robot_utils import busy_wait
|
from lerobot.common.utils.robot_utils import busy_wait
|
||||||
from lerobot.common.utils.utils import log_say
|
from lerobot.common.utils.utils import log_say
|
||||||
@@ -307,6 +305,8 @@ class RobotEnv(gym.Env):
|
|||||||
"""
|
"""
|
||||||
super().reset(seed=seed, options=options)
|
super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
|
self.robot.reset()
|
||||||
|
|
||||||
# Capture the initial observation.
|
# Capture the initial observation.
|
||||||
observation = self._get_observation()
|
observation = self._get_observation()
|
||||||
|
|
||||||
@@ -1003,7 +1003,9 @@ class GripperActionWrapper(gym.ActionWrapper):
|
|||||||
|
|
||||||
gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"]
|
gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"]
|
||||||
|
|
||||||
gripper_action_value = np.clip(gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos)
|
gripper_action_value = np.clip(
|
||||||
|
gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos
|
||||||
|
)
|
||||||
action[-1] = gripper_action_value.item()
|
action[-1] = gripper_action_value.item()
|
||||||
return action
|
return action
|
||||||
|
|
||||||
@@ -1138,7 +1140,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize kinematics instance for the appropriate robot type
|
# Initialize kinematics instance for the appropriate robot type
|
||||||
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
|
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101")
|
||||||
self.kinematics = RobotKinematics(robot_type)
|
self.kinematics = RobotKinematics(robot_type)
|
||||||
self.fk_function = self.kinematics.fk_gripper_tip
|
self.fk_function = self.kinematics.fk_gripper_tip
|
||||||
|
|
||||||
@@ -1152,16 +1154,10 @@ class EEObservationWrapper(gym.ObservationWrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
Enhanced observation with end-effector pose information.
|
Enhanced observation with end-effector pose information.
|
||||||
"""
|
"""
|
||||||
current_joint_pos = self.unwrapped._get_observation()["observation.state"]
|
current_joint_pos = self.unwrapped._get_observation()["agent_pos"]
|
||||||
|
|
||||||
current_ee_pos = self.fk_function(current_joint_pos)
|
current_ee_pos = self.fk_function(current_joint_pos)
|
||||||
observation["observation.state"] = torch.cat(
|
observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos[:3, 3]], -1)
|
||||||
[
|
|
||||||
observation["observation.state"],
|
|
||||||
torch.from_numpy(current_ee_pos[:3, 3]),
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
@@ -1178,9 +1174,7 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
|||||||
where the human can control a leader robot to guide the follower robot's movements.
|
where the human can control a leader robot to guide the follower robot's movements.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, env, teleop_device, use_geared_leader_arm: bool = False, use_gripper=False):
|
||||||
self, env, teleop_device, use_geared_leader_arm: bool = False, use_gripper=False
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Initialize the base leader control wrapper.
|
Initialize the base leader control wrapper.
|
||||||
|
|
||||||
@@ -1322,10 +1316,12 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
|||||||
|
|
||||||
if self.use_gripper:
|
if self.use_gripper:
|
||||||
if self.prev_leader_gripper is None:
|
if self.prev_leader_gripper is None:
|
||||||
self.prev_leader_gripper = np.clip(leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos)
|
self.prev_leader_gripper = np.clip(
|
||||||
|
leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos
|
||||||
|
)
|
||||||
|
|
||||||
# Get gripper action delta based on leader pose
|
# Get gripper action delta based on leader pose
|
||||||
leader_gripper = leader_pos[-1]
|
leader_gripper = leader_pos[-1]
|
||||||
# follower_gripper = follower_pos[-1]
|
# follower_gripper = follower_pos[-1]
|
||||||
gripper_delta = leader_gripper - self.prev_leader_gripper
|
gripper_delta = leader_gripper - self.prev_leader_gripper
|
||||||
|
|
||||||
@@ -1342,7 +1338,7 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
|||||||
action = np.append(action, gripper_action)
|
action = np.append(action, gripper_action)
|
||||||
# action_intervention = np.append(action_intervention, gripper_delta)
|
# action_intervention = np.append(action_intervention, gripper_delta)
|
||||||
|
|
||||||
return action # , action_intervention
|
return action # , action_intervention
|
||||||
|
|
||||||
def _handle_leader_teleoperation(self):
|
def _handle_leader_teleoperation(self):
|
||||||
"""
|
"""
|
||||||
@@ -1389,7 +1385,11 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
|||||||
info["is_intervention"] = is_intervention
|
info["is_intervention"] = is_intervention
|
||||||
info["action_intervention"] = action if is_intervention else None
|
info["action_intervention"] = action if is_intervention else None
|
||||||
|
|
||||||
self.prev_leader_gripper = np.clip(self.robot_leader.bus.sync_read("Present_Position")["gripper"], 0, self.robot_follower.config.max_gripper_pos)
|
self.prev_leader_gripper = np.clip(
|
||||||
|
self.robot_leader.bus.sync_read("Present_Position")["gripper"],
|
||||||
|
0,
|
||||||
|
self.robot_follower.config.max_gripper_pos,
|
||||||
|
)
|
||||||
|
|
||||||
# Check for success or manual termination
|
# Check for success or manual termination
|
||||||
success = self.keyboard_events["episode_success"]
|
success = self.keyboard_events["episode_success"]
|
||||||
@@ -1569,7 +1569,7 @@ class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper):
|
|||||||
self.ee_error_over_time_queue[-1] - self.previous_ee_error_over_time_over_time
|
self.ee_error_over_time_queue[-1] - self.previous_ee_error_over_time_over_time
|
||||||
)
|
)
|
||||||
self.previous_ee_error_over_time_over_time = self.ee_error_over_time_queue[-1]
|
self.previous_ee_error_over_time_over_time = self.ee_error_over_time_queue[-1]
|
||||||
|
|
||||||
self.intervention_threshold = 0.02
|
self.intervention_threshold = 0.02
|
||||||
|
|
||||||
# Determine if intervention should start or stop based on the thresholds set in the constructor
|
# Determine if intervention should start or stop based on the thresholds set in the constructor
|
||||||
@@ -1890,8 +1890,8 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
|||||||
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
|
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
|
||||||
if cfg.wrapper.add_current_to_observation:
|
if cfg.wrapper.add_current_to_observation:
|
||||||
env = AddCurrentToObservation(env=env)
|
env = AddCurrentToObservation(env=env)
|
||||||
if False and cfg.wrapper.add_ee_pose_to_observation:
|
if cfg.wrapper.add_ee_pose_to_observation:
|
||||||
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
|
env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds)
|
||||||
|
|
||||||
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
||||||
|
|
||||||
@@ -1910,9 +1910,9 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
|||||||
if cfg.wrapper:
|
if cfg.wrapper:
|
||||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||||
if cfg.wrapper.use_gripper:
|
if cfg.wrapper.use_gripper:
|
||||||
env = GripperActionWrapper(
|
# env = GripperActionWrapper(
|
||||||
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
# env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
||||||
)
|
# )
|
||||||
if cfg.wrapper.gripper_penalty is not None:
|
if cfg.wrapper.gripper_penalty is not None:
|
||||||
env = GripperPenaltyWrapper(
|
env = GripperPenaltyWrapper(
|
||||||
env=env,
|
env=env,
|
||||||
@@ -1922,7 +1922,9 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
|||||||
# Control mode specific wrappers
|
# Control mode specific wrappers
|
||||||
control_mode = cfg.wrapper.control_mode
|
control_mode = cfg.wrapper.control_mode
|
||||||
if control_mode == "gamepad":
|
if control_mode == "gamepad":
|
||||||
assert isinstance(teleop_device, GamepadTeleop), "teleop_device must be an instance of GamepadTeleop for gamepad control mode"
|
assert isinstance(teleop_device, GamepadTeleop), (
|
||||||
|
"teleop_device must be an instance of GamepadTeleop for gamepad control mode"
|
||||||
|
)
|
||||||
env = GamepadControlWrapper(
|
env = GamepadControlWrapper(
|
||||||
env=env,
|
env=env,
|
||||||
teleop_device=teleop_device,
|
teleop_device=teleop_device,
|
||||||
@@ -2117,11 +2119,10 @@ def record_dataset(env, policy, cfg):
|
|||||||
really_done = success_steps_collected >= cfg.number_of_steps_after_success
|
really_done = success_steps_collected >= cfg.number_of_steps_after_success
|
||||||
|
|
||||||
frame["next.done"] = np.array([really_done], dtype=bool)
|
frame["next.done"] = np.array([really_done], dtype=bool)
|
||||||
frame["task"] = cfg.task
|
|
||||||
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
||||||
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
||||||
)
|
)
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame, task=cfg.task)
|
||||||
|
|
||||||
# Maintain consistent timing
|
# Maintain consistent timing
|
||||||
if cfg.fps:
|
if cfg.fps:
|
||||||
@@ -2233,7 +2234,7 @@ def main(cfg: EnvConfig):
|
|||||||
while num_episode < 10:
|
while num_episode < 10:
|
||||||
start_loop_s = time.perf_counter()
|
start_loop_s = time.perf_counter()
|
||||||
# Sample a new random action from the robot's action space.
|
# Sample a new random action from the robot's action space.
|
||||||
new_random_action = env.action_space.sample()
|
new_random_action = env.action_space.sample() * 0.0
|
||||||
# Update the smoothed action using an exponential moving average.
|
# Update the smoothed action using an exponential moving average.
|
||||||
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
||||||
|
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ from torch import nn
|
|||||||
from torch.multiprocessing import Queue
|
from torch.multiprocessing import Queue
|
||||||
from torch.optim.optimizer import Optimizer
|
from torch.optim.optimizer import Optimizer
|
||||||
|
|
||||||
|
from lerobot.common.cameras import so100_follower_end_effector # noqa: F401
|
||||||
from lerobot.common.constants import (
|
from lerobot.common.constants import (
|
||||||
CHECKPOINTS_DIR,
|
CHECKPOINTS_DIR,
|
||||||
LAST_CHECKPOINT_LINK,
|
LAST_CHECKPOINT_LINK,
|
||||||
@@ -101,6 +102,8 @@ from lerobot.common.datasets.factory import make_dataset
|
|||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
|
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
||||||
|
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
||||||
from lerobot.common.utils.random_utils import set_seed
|
from lerobot.common.utils.random_utils import set_seed
|
||||||
from lerobot.common.utils.train_utils import (
|
from lerobot.common.utils.train_utils import (
|
||||||
get_step_checkpoint_dir,
|
get_step_checkpoint_dir,
|
||||||
|
|||||||
Reference in New Issue
Block a user