Change HILSerlRobotEnvConfig to inherit from EnvConfig

Added support for training the hil_serl classifier with train.py.
Run classifier training with: python lerobot/scripts/train.py --policy.type=hilserl_classifier
Fixes in find_joint_limits, control_robot, and end_effector_control_utils.
This commit is contained in:
Michel Aractingi
2025-03-27 10:23:14 +01:00
parent db897a1619
commit b69132c79d
13 changed files with 388 additions and 340 deletions

View File

@@ -279,10 +279,7 @@ def record(
if not robot.is_connected:
robot.connect()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
if reset_follower:
initial_position = robot.follower_arms["main"].read("Present_Position")
listener, events = init_keyboard_listener(assign_rewards=cfg.assign_rewards)
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,

View File

@@ -1,13 +1,13 @@
import argparse
import logging
import sys
import time
import numpy as np
import torch
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.scripts.server.kinematics import RobotKinematics
import logging
import time
import torch
import numpy as np
import argparse
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import make_robot_env, HILSerlRobotEnvConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig
logging.basicConfig(level=logging.INFO)
@@ -458,12 +458,13 @@ class GamepadControllerHID(InputController):
def test_forward_kinematics(robot, fps=10):
logging.info("Testing Forward Kinematics")
timestep = time.perf_counter()
kinematics = RobotKinematics(robot.robot_type)
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
robot.teleop_step()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
ee_pos = kinematics.fk_gripper_tip(joint_positions)
logging.info(f"EE Position: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -485,21 +486,19 @@ def test_inverse_kinematics(robot, fps=10):
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Inverse Kinematics")
fk_func = RobotKinematics.fk_gripper_tip
kinematics = RobotKinematics(robot.robot_type)
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = fk_func(joint_positions)
ee_pos = kinematics.fk_gripper_tip(joint_positions)
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = fk_func(leader_joint_positions)
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = leader_ee
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -513,10 +512,10 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
fk_func = RobotKinematics.fk_gripper_tip
kinematics = RobotKinematics(robot.robot_type)
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
initial_leader_ee = fk_func(leader_joint_positions)
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = np.diag(np.ones(4))
@@ -525,13 +524,13 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
# Get leader state for teleoperation
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = fk_func(leader_joint_positions)
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
# Get current state
# obs = robot.capture_observation()
# joint_positions = obs["observation.state"].cpu().numpy()
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = fk_func(joint_positions)
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Calculate delta between leader and follower end-effectors
# Scaling factor can be adjusted for sensitivity
@@ -545,9 +544,7 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
# Compute joint targets via inverse kinematics
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
initial_leader_ee = leader_ee.copy()
@@ -580,7 +577,8 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
current_ee_pos = fk_func(joint_positions)
kinematics = RobotKinematics(robot.robot_type)
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Initialize desired position with current position
desired_ee_pos = np.eye(4) # Identity matrix
@@ -595,7 +593,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
# Get current robot state
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = fk_func(joint_positions)
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
@@ -612,9 +610,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
# Only send commands if there's actual movement
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
# Compute joint targets via inverse kinematics
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
@@ -676,7 +672,17 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
# Close the environment
env.close()
if __name__ == "__main__":
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import (
EEActionSpaceConfig,
EnvWrapperConfig,
HILSerlRobotEnvConfig,
make_robot_env,
)
parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument(
"--mode",
@@ -698,12 +704,6 @@ if __name__ == "__main__":
default="so100",
help="Robot type (so100, koch, aloha, etc.)",
)
parser.add_argument(
"--config-path",
type=str,
default=None,
help="Path to the config file in json format",
)
args = parser.parse_args()
@@ -725,7 +725,10 @@ if __name__ == "__main__":
if args.mode.startswith("keyboard"):
controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
elif args.mode.startswith("gamepad"):
controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05)
if sys.platform == "darwin":
controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
else:
controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
# Handle mode categories
if args.mode in ["keyboard", "gamepad"]:
@@ -734,12 +737,14 @@ if __name__ == "__main__":
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
# Gym environment control modes
cfg = HILSerlRobotEnvConfig()
if args.config_path is not None:
cfg = HILSerlRobotEnvConfig.from_json(args.config_path)
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig())
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
)
cfg.wrapper.ee_action_space_params.use_gamepad = False
cfg.device = "cpu"
env = make_robot_env(cfg, robot)
teleoperate_gym_env(env, controller, fps=args.fps)
teleoperate_gym_env(env, controller, fps=cfg.fps)
elif args.mode == "leader":
# Leader-follower modes don't use controllers

View File

@@ -63,9 +63,10 @@ def find_ee_bounds(
if time.perf_counter() - start_episode_t < 5:
continue
kinematics = RobotKinematics(robot.robot_type)
joint_positions = robot.follower_arms["main"].read("Present_Position")
print(f"Joint positions: {joint_positions}")
ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3])
ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3])
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
@@ -81,20 +82,19 @@ def find_ee_bounds(
break
def make_robot(robot_type="so100", mock=True):
def make_robot(robot_type="so100"):
"""
Create a robot instance using the appropriate robot config class.
Args:
robot_type: Robot type string (e.g., "so100", "koch", "aloha")
mock: Whether to use mock mode for hardware (default: True)
Returns:
Robot instance
"""
# Get the appropriate robot config class based on robot_type
robot_config = RobotConfig.get_choice_class(robot_type)(mock=mock)
robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)
robot_config.leader_arms["main"].port = leader_port
robot_config.follower_arms["main"].port = follower_port
@@ -122,18 +122,12 @@ if __name__ == "__main__":
default="so100",
help="Robot type (so100, koch, aloha, etc.)",
)
parser.add_argument(
"--mock",
type=int,
default=1,
help="Use mock mode for hardware simulation",
)
# Only parse known args, leaving robot config args for Hydra if used
args, _ = parser.parse_known_args()
args = parser.parse_args()
# Create robot with the appropriate config
robot = make_robot(args.robot_type, args.mock)
robot = make_robot(args.robot_type)
if args.mode == "joint":
find_joint_bounds(robot, args.control_time_s)

View File

@@ -1,45 +1,37 @@
import logging
import sys
import time
import sys
from dataclasses import dataclass
from threading import Lock
from typing import Annotated, Any, Dict, Tuple
from typing import Annotated, Any, Dict, Optional, Tuple
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
import json
from dataclasses import dataclass
from lerobot.common.envs.utils import preprocess_observation
from lerobot.configs.train import TrainPipelineConfig
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
# reset_follower_position,
reset_follower_position,
)
from typing import Optional
from lerobot.common.utils.utils import log_say
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.scripts.server.kinematics import RobotKinematics
from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.utils.utils import log_say
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO)
@dataclass
class EEActionSpaceConfig:
"""Configuration parameters for end-effector action space."""
x_step_size: float
y_step_size: float
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
use_gamepad: bool = False
@@ -48,6 +40,7 @@ class EEActionSpaceConfig:
@dataclass
class EnvWrapperConfig:
"""Configuration for environment wrappers."""
display_cameras: bool = False
delta_action: float = 0.1
use_relative_joint_positions: bool = True
@@ -64,28 +57,27 @@ class EnvWrapperConfig:
reward_classifier_config_file: Optional[str] = None
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig:
class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment."""
robot: RobotConfig
wrapper: EnvWrapperConfig
env_name: str = "real_robot"
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvWrapperConfig] = None
fps: int = 10
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
task: str = ""
num_episodes: int = 10 # only for record mode
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
@classmethod
def from_json(cls, json_path: str):
with open(json_path, "r") as f:
config = json.load(f)
return cls(**config)
def gym_kwargs(self) -> dict:
return {}
class HILSerlRobotEnv(gym.Env):
"""
@@ -580,8 +572,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
if key_crop not in self.env.observation_space.keys(): # noqa: SIM118
raise ValueError(f"Key {key_crop} not in observation space")
for key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width)
new_shape = (3, resize_size[0], resize_size[1])
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
self.resize_size = resize_size
@@ -1097,9 +1088,7 @@ class ActionScaleWrapper(gym.ActionWrapper):
return action * self.scale_vector, is_intervention
def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
def make_robot_env(cfg) -> gym.vector.VectorEnv:
"""
Factory function to create a vectorized robot environment.
@@ -1111,16 +1100,16 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
Returns:
A vectorized gym environment with all the necessary wrappers applied.
"""
if "maniskill" in cfg.name:
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
# if "maniskill" in cfg.name:
# from lerobot.scripts.server.maniskill_manipulator import make_maniskill
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
env = make_maniskill(
cfg=cfg,
n_envs=1,
)
return env
robot = cfg.robot
# logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
# env = make_maniskill(
# cfg=cfg,
# n_envs=1,
# )
# return env
robot = make_robot_from_config(cfg.robot)
# Create base environment
env = HILSerlRobotEnv(
robot=robot,
@@ -1150,10 +1139,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
if (
cfg.wrapper.ee_action_space_params is not None
and cfg.wrapper.ee_action_space_params.use_gamepad
):
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
env = GamepadControlWrapper(
env=env,
@@ -1169,10 +1155,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.wrapper.reset_time_s,
)
if (
cfg.wrapper.ee_action_space_params is None
and cfg.wrapper.joint_masking_action_space is not None
):
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
env = BatchCompitableWrapper(env=env)
@@ -1180,7 +1163,10 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
def get_classifier(cfg):
if cfg.wrapper.reward_classifier_pretrained_path is None or cfg.wrapper.reward_classifier_config_file is None:
if (
cfg.wrapper.reward_classifier_pretrained_path is None
or cfg.wrapper.reward_classifier_config_file is None
):
return None
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
@@ -1258,7 +1244,8 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
# Record episodes
episode_index = 0
while episode_index < cfg.record_num_episodes:
recorded_action = None
while episode_index < cfg.num_episodes:
obs, _ = env.reset()
start_episode_t = time.perf_counter()
log_say(f"Recording episode {episode_index}", play_sounds=True)
@@ -1279,16 +1266,19 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
break
# For teleop, get action from intervention
if policy is None:
action = {"action": info["action_intervention"].cpu().squeeze(0).float()}
recorded_action = {
"action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action
}
# Process observation for dataset
obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
obs["observation.images.side"] = torch.clamp(obs["observation.images.side"], 0, 1)
# Add frame to dataset
frame = {**obs, **action}
frame["next.reward"] = reward
frame["next.done"] = terminated or truncated
frame = {**obs, **recorded_action}
frame["next.reward"] = np.array([reward], dtype=np.float32)
frame["next.done"] = np.array([terminated or truncated], dtype=bool)
frame["task"] = cfg.task
dataset.add_frame(frame)
# Maintain consistent timing
@@ -1309,9 +1299,9 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
episode_index += 1
# Finalize dataset
dataset.consolidate(run_compute_stats=True)
# dataset.consolidate(run_compute_stats=True)
if cfg.push_to_hub:
dataset.push_to_hub(cfg.repo_id)
dataset.push_to_hub()
def replay_episode(env, repo_id, root=None, episode=0):
@@ -1334,82 +1324,69 @@ def replay_episode(env, repo_id, root=None, episode=0):
busy_wait(1 / 10 - dt_s)
# @parser.wrap()
# def main(cfg):
@parser.wrap()
def main(cfg: EnvConfig):
env = make_robot_env(cfg)
# robot = make_robot_from_config(cfg.robot)
if cfg.mode == "record":
policy = None
if cfg.pretrained_policy_name_or_path is not None:
from lerobot.common.policies.sac.modeling_sac import SACPolicy
# reward_classifier = None #get_classifier(
# # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
# # )
# user_relative_joint_positions = True
policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
policy.to(cfg.device)
policy.eval()
# env = make_robot_env(cfg, robot)
record_dataset(
env,
policy=None,
cfg=cfg,
)
exit()
# if cfg.mode == "record":
# policy = None
# if cfg.pretrained_policy_name_or_path is not None:
# from lerobot.common.policies.sac.modeling_sac import SACPolicy
if cfg.mode == "replay":
replay_episode(
env,
cfg.replay_repo_id,
root=cfg.dataset_root,
episode=cfg.replay_episode,
)
exit()
# policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
# policy.to(cfg.device)
# policy.eval()
env.reset()
# record_dataset(
# env,
# cfg.repo_id,
# root=cfg.dataset_root,
# num_episodes=cfg.num_episodes,
# fps=cfg.fps,
# task_description=cfg.task,
# policy=policy,
# )
# exit()
# Retrieve the robot's action space for joint commands.
action_space_robot = env.action_space.spaces[0]
# if cfg.mode == "replay":
# replay_episode(
# env,
# cfg.replay_repo_id,
# root=cfg.dataset_root,
# episode=cfg.replay_episode,
# )
# exit()
# Initialize the smoothed action as a random sample.
smoothed_action = action_space_robot.sample()
# env.reset()
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
alpha = 1.0
# # Retrieve the robot's action space for joint commands.
# action_space_robot = env.action_space.spaces[0]
num_episode = 0
sucesses = []
while num_episode < 20:
start_loop_s = time.perf_counter()
# Sample a new random action from the robot's action space.
new_random_action = action_space_robot.sample()
# Update the smoothed action using an exponential moving average.
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# # Initialize the smoothed action as a random sample.
# smoothed_action = action_space_robot.sample()
# Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
if terminated or truncated:
sucesses.append(reward)
env.reset()
num_episode += 1
# # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
# # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
# alpha = 1.0
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / cfg.fps - dt_s)
# num_episode = 0
# sucesses = []
# while num_episode < 20:
# start_loop_s = time.perf_counter()
# # Sample a new random action from the robot's action space.
# new_random_action = action_space_robot.sample()
# # Update the smoothed action using an exponential moving average.
# smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
logging.info(f"Success after 20 steps {sucesses}")
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
# # Execute the step: wrap the NumPy action in a torch tensor.
# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
# if terminated or truncated:
# sucesses.append(reward)
# env.reset()
# num_episode += 1
# dt_s = time.perf_counter() - start_loop_s
# busy_wait(1 / cfg.fps - dt_s)
# logging.info(f"Success after 20 steps {sucesses}")
# logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
# if __name__ == "__main__":
# main()
if __name__ == "__main__":
make_robot_env()
main()

View File

@@ -15,12 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
import os
from pathlib import Path
from pprint import pformat
import draccus
import grpc
@@ -30,35 +30,42 @@ import hilserl_pb2_grpc # type: ignore
import torch
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs import parser
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
from lerobot.common.policies.sac.modeling_sac import SACConfig, SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
save_training_state,
update_last_checkpoint,
)
from lerobot.common.utils.train_utils import (
load_training_state as utils_load_training_state,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_logging,
)
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
@@ -70,47 +77,39 @@ from lerobot.scripts.server.buffer import (
state_to_bytes,
)
from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
"""
Handle the resume logic for training.
If resume is True:
- Verifies that a checkpoint exists
- Loads the checkpoint configuration
- Logs resumption details
- Returns the checkpoint configuration
If resume is False:
- Checks if an output directory exists (to prevent accidental overwriting)
- Returns the original configuration
Args:
cfg (TrainPipelineConfig): The training configuration
Returns:
TrainPipelineConfig: The updated configuration
Raises:
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
"""
out_dir = cfg.output_dir
# Case 1: Not resuming, but need to check if directory exists to prevent overwrites
if not cfg.resume:
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if os.path.exists(checkpoint_dir):
raise RuntimeError(
f"Output directory {checkpoint_dir} already exists. "
"Use `resume=true` to resume training."
f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training."
)
return cfg
@@ -131,7 +130,7 @@ def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
# Load config using Draccus
checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
# Ensure resume flag is set in returned config
checkpoint_cfg.resume = True
return checkpoint_cfg
@@ -143,11 +142,11 @@ def load_training_state(
):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.
Args:
cfg (TrainPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into
Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
"""
@@ -156,23 +155,23 @@ def load_training_state(
# Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
logging.info(f"Loading training state from {checkpoint_dir}")
try:
# Use the utility function from train_utils which loads the optimizer state
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
interaction_step = 0
if os.path.exists(training_state_path):
training_state = torch.load(training_state_path, weights_only=False)
interaction_step = training_state.get("interaction_step", 0)
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
except Exception as e:
logging.error(f"Failed to load training state: {e}")
return None, None
@@ -181,7 +180,7 @@ def load_training_state(
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
"""
Log information about the training process.
Args:
cfg (TrainPipelineConfig): Training configuration
policy (nn.Module): Policy model
@@ -189,7 +188,6 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.policy.online_steps=}")
@@ -197,19 +195,15 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(
cfg: TrainPipelineConfig,
device: str,
storage_device: str
) -> ReplayBuffer:
def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
"""
Initialize a replay buffer, either empty or from a dataset if resuming.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
Returns:
ReplayBuffer: Initialized replay buffer
"""
@@ -224,7 +218,7 @@ def initialize_replay_buffer(
logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset")
# NOTE: In RL it is possible to not have a dataset.
repo_id = None
if cfg.dataset is not None:
@@ -250,13 +244,13 @@ def initialize_offline_replay_buffer(
) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking
Returns:
ReplayBuffer: Initialized offline replay buffer
"""
@@ -314,7 +308,7 @@ def start_learner_threads(
) -> None:
"""
Start the learner threads for training.
Args:
cfg (TrainPipelineConfig): Training configuration
wandb_logger (WandBLogger | None): Logger for metrics
@@ -512,17 +506,19 @@ def add_actor_information_and_train(
logging.info("Initializing policy")
# Get checkpoint dir for resuming
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
checkpoint_dir = (
os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
)
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# ds_meta=cfg.dataset,
env_cfg=cfg.env
env_cfg=cfg.env,
)
# Update the policy config with the grad_clip_norm value from training config if it exists
clip_grad_norm_value:float = cfg.policy.grad_clip_norm
clip_grad_norm_value: float = cfg.policy.grad_clip_norm
# compile policy
policy = torch.compile(policy)
@@ -536,7 +532,7 @@ def add_actor_information_and_train(
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy= policy)
log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
batch_size = cfg.batch_size
@@ -615,14 +611,10 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
# Log interaction messages with WandB if available
if wandb_logger:
wandb_logger.log_dict(
d=interaction_message,
mode="train",
custom_step_key="Interaction step"
)
wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step")
logging.debug("[LEARNER] Received interactions")
@@ -636,7 +628,9 @@ def add_actor_information_and_train(
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
rewards = batch["reward"]
@@ -759,14 +753,10 @@ def add_actor_information_and_train(
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step
# Log training metrics
if wandb_logger:
wandb_logger.log_dict(
d=training_infos,
mode="train",
custom_step_key="Optimization step"
)
wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
@@ -795,29 +785,19 @@ def add_actor_information_and_train(
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
)
# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint
save_checkpoint(
checkpoint_dir,
optimization_step,
cfg,
policy,
optimizers,
scheduler=None
)
save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
# Save interaction step manually
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True)
training_state = {
"step": optimization_step,
"interaction_step": interaction_step
}
training_state = {"step": optimization_step, "interaction_step": interaction_step}
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
@@ -826,17 +806,13 @@ def add_actor_information_and_train(
dataset_dir = os.path.join(cfg.output_dir, "dataset")
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
shutil.rmtree(dataset_dir)
# Save dataset
# NOTE: Handle the case where the dataset repo id is not specified in the config
# e.g. RL training without demonstration data
# e.g. RL training without demonstration data
repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
replay_buffer.to_lerobot_dataset(
repo_id=repo_id_buffer_save,
fps=fps,
root=dataset_dir
)
replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
if offline_replay_buffer is not None:
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
@@ -882,9 +858,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
params=policy.actor.parameters_to_optimize,
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
@@ -898,19 +872,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
"""
Main training function that initializes and runs the training process.
Args:
cfg (TrainPipelineConfig): The training configuration
job_name (str | None, optional): Job name for logging. Defaults to None.
"""
cfg.validate()
# if cfg.output_dir is None:
# raise ValueError("Output directory must be specified in config")
if job_name is None:
job_name = cfg.job_name
if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter")
@@ -920,11 +894,12 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
# Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project:
from lerobot.common.utils.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
# Handle resume logic
cfg = handle_resume_logic(cfg)
@@ -944,9 +919,9 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
@parser.wrap()
def train_cli(cfg: TrainPipelineConfig):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
# Use the job_name from the config

View File

@@ -122,6 +122,9 @@ def make_optimizer_and_scheduler(cfg, policy):
optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg)
elif cfg.policy.name == "hilserl_classifier":
optimizer = torch.optim.AdamW(policy.parameters(), cfg.policy.learning_rate)
lr_scheduler = None
else:
raise NotImplementedError()

View File

@@ -16,7 +16,6 @@ import time
from contextlib import nullcontext
from pprint import pformat
import hydra
import numpy as np
import torch
import torch.nn as nn
@@ -32,11 +31,8 @@ from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
format_big_number,
@@ -296,8 +292,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
init_logging()
logging.info(OmegaConf.to_yaml(cfg))
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
# Initialize training environment
device = get_safe_torch_device(cfg.device, log=True)
set_global_seed(cfg.seed)