forked from tangger/lerobot
Fix linter issue
This commit is contained in:
@@ -42,7 +42,7 @@ Replace the `dataset_repo_id` field with the identifier for your dataset, which
|
||||
dataset_repo_id: "my_dataset_repo_id"
|
||||
## Typical logs and metrics
|
||||
```
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overriden by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
|
||||
@@ -262,8 +262,6 @@ def control_loop(
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
if teleoperate:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
@@ -326,7 +324,7 @@ def reset_follower_position(robot: Robot, target_position):
|
||||
current_position = robot.follower_arms["main"].read("Present_Position")
|
||||
trajectory = torch.from_numpy(
|
||||
np.linspace(current_position, target_position, 50)
|
||||
) # NOTE: 30 is just an aribtrary number
|
||||
) # NOTE: 30 is just an arbitrary number
|
||||
for pose in trajectory:
|
||||
robot.send_action(pose)
|
||||
busy_wait(0.015)
|
||||
|
||||
@@ -1,412 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
|
||||
|
||||
Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
-p outputs/train/model/checkpoints/005000/pretrained_model \
|
||||
eval.n_episodes=10
|
||||
```
|
||||
|
||||
Test reward classifier with teleoperation (you need to press space to take over)
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
|
||||
--reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
|
||||
--display-cameras 1
|
||||
```
|
||||
|
||||
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
||||
for running training on the real robot.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
busy_wait,
|
||||
is_headless,
|
||||
reset_follower_position,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
log_say,
|
||||
)
|
||||
|
||||
|
||||
def get_classifier(pretrained_path, config_path):
|
||||
if pretrained_path is None or config_path is None:
|
||||
return
|
||||
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
|
||||
cfg = init_hydra_config(config_path)
|
||||
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
||||
model = Classifier(classifier_config)
|
||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
||||
model = model.to("mps")
|
||||
return model
|
||||
|
||||
|
||||
def rollout(
|
||||
robot: Robot,
|
||||
policy: Policy,
|
||||
reward_classifier,
|
||||
fps: int,
|
||||
control_time_s: float = 20,
|
||||
use_amp: bool = True,
|
||||
display_cameras: bool = False,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout on the real robot.
|
||||
|
||||
The return dictionary contains:
|
||||
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
keys. NOTE the that this has an extra sequence element relative to the other keys in the
|
||||
dictionary. This is because an extra observation is included for after the environment is
|
||||
terminated or truncated.
|
||||
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
|
||||
including the last observations).
|
||||
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
|
||||
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
|
||||
environment termination/truncation).
|
||||
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
the first True is followed by True's all the way till the end. This can be used for masking
|
||||
extraneous elements from the sequences above.
|
||||
|
||||
Args:
|
||||
robot: The robot class that defines the interface with the real robot.
|
||||
policy: The policy. Must be a PyTorch nn module.
|
||||
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
# TODO (michel-aractingi): Infer the device from policy parameters when policy is added
|
||||
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
|
||||
# device = get_device_from_parameters(policy)
|
||||
|
||||
# define keyboard listener
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||
# policy.reset()
|
||||
|
||||
# NOTE: sorting to make sure the key sequence is the same during training and testing.
|
||||
observation = robot.capture_observation()
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
image_keys.sort()
|
||||
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
all_successes = []
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
init_pos = robot.follower_arms["main"].read("Present_Position")
|
||||
timestamp = 0.0
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
# Apply the next action.
|
||||
while events["pause_policy"] and not events["human_intervention_step"]:
|
||||
busy_wait(0.5)
|
||||
|
||||
if events["human_intervention_step"]:
|
||||
# take over the robot's actions
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
action = action["action"] # teleop step returns torch tensors but in a dict
|
||||
else:
|
||||
# explore with policy
|
||||
with torch.inference_mode():
|
||||
# TODO (michel-aractingi) replace this part with policy (predict_action)
|
||||
action = robot.follower_arms["main"].read("Present_Position")
|
||||
action = torch.from_numpy(action)
|
||||
robot.send_action(action)
|
||||
# action = predict_action(observation, policy, device, use_amp)
|
||||
|
||||
observation = robot.capture_observation()
|
||||
images = []
|
||||
for key in image_keys:
|
||||
if display_cameras:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
images.append(observation[key].to("mps"))
|
||||
|
||||
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
|
||||
all_rewards.append(reward)
|
||||
|
||||
# print("REWARD : ", reward)
|
||||
|
||||
all_actions.append(action)
|
||||
all_successes.append(torch.tensor([False]))
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["human_intervention_step"] = False
|
||||
events["pause_policy"] = False
|
||||
break
|
||||
|
||||
reset_follower_position(robot, target_position=init_pos)
|
||||
|
||||
dones = torch.tensor([False] * len(all_actions))
|
||||
dones[-1] = True
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
"next.reward": torch.stack(all_rewards, dim=1),
|
||||
"next.success": torch.stack(all_successes, dim=1),
|
||||
"done": dones,
|
||||
}
|
||||
|
||||
listener.stop()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def eval_policy(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module,
|
||||
fps: float,
|
||||
n_episodes: int,
|
||||
control_time_s: int = 20,
|
||||
use_amp: bool = True,
|
||||
display_cameras: bool = False,
|
||||
reward_classifier_pretrained_path: str | None = None,
|
||||
reward_classifier_config_file: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy.
|
||||
n_episodes: The number of episodes to evaluate.
|
||||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"""
|
||||
# TODO (michel-aractingi) comment this out for testing with a fixed policy
|
||||
# assert isinstance(policy, Policy)
|
||||
# policy.eval()
|
||||
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
rollouts = []
|
||||
|
||||
start_eval = time.perf_counter()
|
||||
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
|
||||
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
|
||||
|
||||
for _ in progbar:
|
||||
rollout_data = rollout(
|
||||
robot,
|
||||
policy,
|
||||
reward_classifier,
|
||||
fps,
|
||||
control_time_s,
|
||||
use_amp,
|
||||
display_cameras,
|
||||
)
|
||||
|
||||
rollouts.append(rollout_data)
|
||||
sum_rewards.append(sum(rollout_data["next.reward"]))
|
||||
max_rewards.append(max(rollout_data["next.reward"]))
|
||||
successes.append(rollout_data["next.success"][-1])
|
||||
|
||||
info = {
|
||||
"per_episode": [
|
||||
{
|
||||
"episode_ix": i,
|
||||
"sum_reward": sum_reward,
|
||||
"max_reward": max_reward,
|
||||
"pc_success": success * 100,
|
||||
}
|
||||
for i, (sum_reward, max_reward, success) in enumerate(
|
||||
zip(
|
||||
sum_rewards[:n_episodes],
|
||||
max_rewards[:n_episodes],
|
||||
successes[:n_episodes],
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
],
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
|
||||
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
|
||||
"pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
|
||||
"eval_s": time.time() - start_eval,
|
||||
"eval_ep_s": (time.time() - start_eval) / n_episodes,
|
||||
},
|
||||
}
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["pause_policy"] = False
|
||||
events["human_intervention_step"] = False
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
listener = None
|
||||
return listener, events
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.space:
|
||||
# check if first space press then pause the policy for the user to get ready
|
||||
# if second space press then the user is ready to start intervention
|
||||
if not events["pause_policy"]:
|
||||
print(
|
||||
"Space key pressed. Human intervention required.\n"
|
||||
"Place the leader in similar pose to the follower and press space again."
|
||||
)
|
||||
events["pause_policy"] = True
|
||||
log_say(
|
||||
"Human intervention stage. Get ready to take over.",
|
||||
play_sounds=True,
|
||||
)
|
||||
else:
|
||||
events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
log_say("Starting human intervention.", play_sounds=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
group.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
help=(
|
||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
"Where to save the evaluation outputs. If not provided, outputs are saved in "
|
||||
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-cameras",
|
||||
help=("Whether to display the camera feed while the rollout is happening"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-pretrained-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the pretrained classifier weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-config-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a yaml config file that is necessary to build the reward classifier model.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
eval_policy(
|
||||
robot,
|
||||
None,
|
||||
fps=40,
|
||||
n_episodes=2,
|
||||
control_time_s=100,
|
||||
display_cameras=args.display_cameras,
|
||||
reward_classifier_config_file=args.reward_classifier_config_file,
|
||||
reward_classifier_pretrained_path=args.reward_classifier_pretrained_path,
|
||||
)
|
||||
@@ -224,7 +224,7 @@ def act_with_policy(
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
policy: SACPolicy = make_policy(
|
||||
@@ -278,7 +278,7 @@ def act_with_policy(
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# TODO: Check the shape
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
|
||||
@@ -269,7 +269,7 @@ if __name__ == "__main__":
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
|
||||
@@ -262,7 +262,7 @@ class GamepadController(InputController):
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# LT button (7) for openning gripper
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = True
|
||||
|
||||
@@ -421,45 +421,44 @@ class GamepadControllerHID(InputController):
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
if data:
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
# Apply deadzone
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
# Apply deadzone
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
@@ -618,7 +617,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get currrent robot state
|
||||
# Get current robot state
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
@@ -635,7 +634,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
|
||||
desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
|
||||
|
||||
# Only send commands if there's actual movement
|
||||
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
|
||||
@@ -678,7 +677,7 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||
action = np.array([delta_x, delta_y, delta_z])
|
||||
|
||||
# Skip if no movement
|
||||
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Step the environment - pass action as a tensor with intervention flag
|
||||
action_tensor = torch.from_numpy(action.astype(np.float32))
|
||||
obs, reward, terminated, truncated, info = env.step((action_tensor, False))
|
||||
|
||||
@@ -106,7 +106,7 @@ class HILSerlRobotEnv(gym.Env):
|
||||
- The action space is defined as a Tuple where:
|
||||
• The first element is a Box space representing joint position commands. It is defined as relative (delta)
|
||||
or absolute, based on the configuration.
|
||||
• ThE SECONd element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
|
||||
• The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
|
||||
"""
|
||||
example_obs = self.robot.capture_observation()
|
||||
|
||||
@@ -384,7 +384,7 @@ class ActionRepeatWrapper(gym.Wrapper):
|
||||
class RewardWrapper(gym.Wrapper):
|
||||
def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
|
||||
"""
|
||||
Wrapper to add reward prediction to the environment, it use a trained classifer.
|
||||
Wrapper to add reward prediction to the environment, it use a trained classifier.
|
||||
|
||||
cfg.
|
||||
env: The environment to wrap
|
||||
@@ -414,7 +414,7 @@ class RewardWrapper(gym.Wrapper):
|
||||
if self.reward_classifier is not None
|
||||
else 0.0
|
||||
)
|
||||
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
|
||||
info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
if success == 1.0:
|
||||
terminated = True
|
||||
@@ -784,11 +784,11 @@ class ResetWrapper(gym.Wrapper):
|
||||
while time.perf_counter() - start_time < self.reset_time_s:
|
||||
self.robot.teleop_step()
|
||||
|
||||
log_say("Manual reseting of the environment done.", play_sounds=True)
|
||||
log_say("Manual reset of the environment done.", play_sounds=True)
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class BatchCompitableWrapper(gym.ObservationWrapper):
|
||||
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
@@ -823,10 +823,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
|
||||
|
||||
def step(self, action):
|
||||
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
|
||||
if isinstance(action, tuple):
|
||||
gripper_action = action[0][-1]
|
||||
else:
|
||||
gripper_action = action[-1]
|
||||
gripper_action = action[0][-1] if isinstance(action, tuple) else action[-1]
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
gripper_penalty = self.reward(reward, gripper_action)
|
||||
|
||||
@@ -1279,7 +1276,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
)
|
||||
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
||||
env = BatchCompitableWrapper(env=env)
|
||||
env = BatchCompatibleWrapper(env=env)
|
||||
|
||||
return env
|
||||
|
||||
@@ -1492,7 +1489,7 @@ def main(cfg: EnvConfig):
|
||||
alpha = 1.0
|
||||
|
||||
num_episode = 0
|
||||
sucesses = []
|
||||
successes = []
|
||||
while num_episode < 20:
|
||||
start_loop_s = time.perf_counter()
|
||||
# Sample a new random action from the robot's action space.
|
||||
@@ -1503,15 +1500,15 @@ def main(cfg: EnvConfig):
|
||||
# Execute the step: wrap the NumPy action in a torch tensor.
|
||||
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
|
||||
if terminated or truncated:
|
||||
sucesses.append(reward)
|
||||
successes.append(reward)
|
||||
env.reset()
|
||||
num_episode += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_s
|
||||
busy_wait(1 / cfg.fps - dt_s)
|
||||
|
||||
logging.info(f"Success after 20 steps {sucesses}")
|
||||
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
||||
logging.info(f"Success after 20 steps {successes}")
|
||||
logging.info(f"success rate {sum(successes) / len(successes)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -319,7 +319,7 @@ def add_actor_information_and_train(
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message, transition = None, None
|
||||
interaction_message = None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
@@ -654,7 +654,7 @@ def start_learner_server(
|
||||
|
||||
shutdown_event.wait()
|
||||
logging.info("[LEARNER] Stopping gRPC server...")
|
||||
server.stop(learner_service.STUTDOWN_TIMEOUT)
|
||||
server.stop(learner_service.SHUTDOWN_TIMEOUT)
|
||||
logging.info("[LEARNER] gRPC server stopped")
|
||||
|
||||
|
||||
@@ -719,7 +719,7 @@ def save_training_checkpoint(
|
||||
# Update the "last" symlink
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
|
||||
# TODO : temporarly save replay buffer here, remove later when on the robot
|
||||
# TODO : temporary save replay buffer here, remove later when on the robot
|
||||
# We want to control this with the keyboard inputs
|
||||
dataset_dir = os.path.join(cfg.output_dir, "dataset")
|
||||
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
|
||||
@@ -889,7 +889,7 @@ def load_training_state(
|
||||
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
|
||||
interaction_step = 0
|
||||
if os.path.exists(training_state_path):
|
||||
training_state = torch.load(training_state_path, weights_only=False)
|
||||
training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load
|
||||
interaction_step = training_state.get("interaction_step", 0)
|
||||
|
||||
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
|
||||
|
||||
@@ -8,13 +8,13 @@ from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_b
|
||||
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
STUTDOWN_TIMEOUT = 10
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
def __init__(
|
||||
self,
|
||||
shutdown_event: Event,
|
||||
shutdown_event: Event, # type: ignore
|
||||
parameters_queue: Queue,
|
||||
seconds_between_pushes: float,
|
||||
transition_queue: Queue,
|
||||
@@ -26,7 +26,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
self.transition_queue = transition_queue
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
|
||||
def StreamParameters(self, request, context):
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
@@ -48,7 +48,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context):
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
@@ -62,7 +62,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context):
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
@@ -76,5 +76,5 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context):
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
@@ -219,50 +219,3 @@ def make_maniskill(
|
||||
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
# @parser.wrap()
|
||||
# def main(cfg: TrainPipelineConfig):
|
||||
# """Main function to run the ManiSkill environment."""
|
||||
# # Create the ManiSkill environment
|
||||
# env = make_maniskill(cfg.env, n_envs=1)
|
||||
|
||||
# # Reset the environment
|
||||
# obs, info = env.reset()
|
||||
|
||||
# # Run a simple interaction loop
|
||||
# sum_reward = 0
|
||||
# for i in range(100):
|
||||
# # Sample a random action
|
||||
# action = env.action_space.sample()
|
||||
|
||||
# # Step the environment
|
||||
# start_time = time.perf_counter()
|
||||
# obs, reward, terminated, truncated, info = env.step(action)
|
||||
# step_time = time.perf_counter() - start_time
|
||||
# sum_reward += reward
|
||||
# # Log information
|
||||
|
||||
# # Reset if episode terminated
|
||||
# if terminated or truncated:
|
||||
# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
|
||||
# sum_reward = 0
|
||||
# obs, info = env.reset()
|
||||
|
||||
# # Close the environment
|
||||
# env.close()
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# logging.basicConfig(level=logging.INFO)
|
||||
# main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import draccus
|
||||
|
||||
config = ManiskillEnvConfig()
|
||||
draccus.set_config_type("json")
|
||||
draccus.dump(
|
||||
config=config,
|
||||
stream=open(file="run_config.json", mode="w"),
|
||||
)
|
||||
|
||||
@@ -53,84 +53,6 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(cfg, policy):
|
||||
if cfg.policy.name == "act":
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": cfg.training.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts,
|
||||
lr=cfg.training.lr,
|
||||
weight_decay=cfg.training.weight_decay,
|
||||
)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "diffusion":
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(),
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
|
||||
elif policy.name == "sac":
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||
{
|
||||
"params": policy.critic_ensemble.parameters(),
|
||||
"lr": policy.config.critic_lr,
|
||||
},
|
||||
{
|
||||
"params": policy.temperature.parameters(),
|
||||
"lr": policy.config.temperature_lr,
|
||||
},
|
||||
]
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import (
|
||||
VQBeTOptimizer,
|
||||
VQBeTScheduler,
|
||||
)
|
||||
|
||||
optimizer = VQBeTOptimizer(policy, cfg)
|
||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||
elif cfg.policy.name == "hilserl_classifier":
|
||||
optimizer = torch.optim.AdamW(policy.parameters(), cfg.policy.learning_rate)
|
||||
lr_scheduler = None
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
policy: PreTrainedPolicy,
|
||||
|
||||
@@ -1,466 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import optim
|
||||
from torch.autograd import profiler
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import random_shift
|
||||
|
||||
|
||||
def get_model(cfg, logger): # noqa I001
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
model = Classifier(classifier_config)
|
||||
if cfg.resume:
|
||||
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
|
||||
return model
|
||||
|
||||
|
||||
def create_balanced_sampler(dataset, cfg):
|
||||
# Get underlying dataset if using Subset
|
||||
original_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
|
||||
|
||||
# Get indices if using Subset (for slicing)
|
||||
indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
|
||||
|
||||
# Get labels from Hugging Face dataset
|
||||
if indices is not None:
|
||||
# Get subset of labels using Hugging Face's select()
|
||||
hf_subset = original_dataset.hf_dataset.select(indices)
|
||||
labels = hf_subset[cfg.training.label_key]
|
||||
else:
|
||||
# Get all labels directly
|
||||
labels = original_dataset.hf_dataset[cfg.training.label_key]
|
||||
|
||||
labels = torch.stack(labels)
|
||||
_, counts = torch.unique(labels, return_counts=True)
|
||||
class_weights = 1.0 / counts.float()
|
||||
sample_weights = class_weights[labels]
|
||||
|
||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||
|
||||
|
||||
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
|
||||
# Check if the device supports AMP
|
||||
# Here is an example of the issue that says that MPS doesn't support AMP properply
|
||||
return cfg.training.use_amp and device.type in ("cuda", "cpu")
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||
# Single epoch training loop with AMP support and progress tracking
|
||||
model.train()
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc="Training")
|
||||
for batch_idx, batch in enumerate(pbar):
|
||||
start_time = time.perf_counter()
|
||||
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
|
||||
images = [random_shift(img, 4) for img in images]
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Backward pass with gradient scaling if AMP enabled
|
||||
optimizer.zero_grad()
|
||||
if cfg.training.use_amp:
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
|
||||
current_acc = 100 * correct / total
|
||||
train_info = {
|
||||
"loss": loss.item(),
|
||||
"accuracy": current_acc,
|
||||
"dataloading_s": time.perf_counter() - start_time,
|
||||
}
|
||||
|
||||
logger.log_dict(train_info, step + batch_idx, mode="train")
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
|
||||
|
||||
|
||||
def validate(model, val_loader, criterion, device, logger, cfg):
|
||||
# Validation loop with metric tracking and sample logging
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
batch_start_time = time.perf_counter()
|
||||
samples = []
|
||||
running_loss = 0
|
||||
inference_times = []
|
||||
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
|
||||
):
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
|
||||
with (
|
||||
profiler.profile(record_shapes=True) as prof,
|
||||
profiler.record_function("model_inference"),
|
||||
):
|
||||
outputs = model(images)
|
||||
inference_times.append(
|
||||
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
|
||||
)
|
||||
else:
|
||||
outputs = model(images)
|
||||
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
running_loss += loss.item()
|
||||
|
||||
# Log sample predictions for visualization
|
||||
if len(samples) < cfg.eval.num_samples_to_log:
|
||||
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
|
||||
if model.config.num_classes == 2:
|
||||
confidence = round(outputs.probabilities[i].item(), 3)
|
||||
else:
|
||||
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
|
||||
samples.append(
|
||||
{
|
||||
**{
|
||||
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
|
||||
for img_idx, img_key in enumerate(cfg.training.image_keys)
|
||||
},
|
||||
"true_label": labels[i].item(),
|
||||
"predicted": predictions[i].item(),
|
||||
"confidence": confidence,
|
||||
}
|
||||
)
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
avg_loss = running_loss / len(val_loader)
|
||||
print(f"Average validation loss {avg_loss}, and accuracy {accuracy}")
|
||||
|
||||
eval_info = {
|
||||
"loss": avg_loss,
|
||||
"accuracy": accuracy,
|
||||
"eval_s": time.perf_counter() - batch_start_time,
|
||||
"eval/prediction_samples": wandb.Table(
|
||||
data=[list(s.values()) for s in samples],
|
||||
columns=list(samples[0].keys()),
|
||||
)
|
||||
if logger._cfg.wandb.enable
|
||||
else None,
|
||||
}
|
||||
|
||||
if len(inference_times) > 0:
|
||||
eval_info["inference_time_avg"] = np.mean(inference_times)
|
||||
eval_info["inference_time_median"] = np.median(inference_times)
|
||||
eval_info["inference_time_std"] = np.std(inference_times)
|
||||
eval_info["inference_time_batch_size"] = val_loader.batch_size
|
||||
|
||||
print(
|
||||
f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
|
||||
)
|
||||
|
||||
return accuracy, eval_info
|
||||
|
||||
|
||||
def benchmark_inference_time(model, dataset, logger, cfg, device, step):
|
||||
if not cfg.training.profile_inference_time:
|
||||
return
|
||||
|
||||
iters = cfg.training.profile_inference_time_iters
|
||||
inference_times = []
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
num_workers=cfg.training.num_workers,
|
||||
sampler=RandomSampler(dataset),
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for _ in tqdm(range(iters), desc="Benchmarking inference time"):
|
||||
x = next(iter(loader))
|
||||
x = [x[img_key].to(device) for img_key in cfg.training.image_keys]
|
||||
|
||||
# Warm up
|
||||
for _ in range(10):
|
||||
_ = model(x)
|
||||
|
||||
# sync the device
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
|
||||
with (
|
||||
profiler.profile(record_shapes=True) as prof,
|
||||
profiler.record_function("model_inference"),
|
||||
):
|
||||
_ = model(x)
|
||||
|
||||
inference_times.append(
|
||||
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
|
||||
)
|
||||
|
||||
inference_times = np.array(inference_times)
|
||||
avg, median, std = (
|
||||
inference_times.mean(),
|
||||
np.median(inference_times),
|
||||
inference_times.std(),
|
||||
)
|
||||
print(
|
||||
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
|
||||
)
|
||||
if logger._cfg.wandb.enable:
|
||||
logger.log_dict(
|
||||
{
|
||||
"inference_time_benchmark_avg": avg,
|
||||
"inference_time_benchmark_median": median,
|
||||
"inference_time_benchmark_std": std,
|
||||
},
|
||||
step + 1,
|
||||
mode="eval",
|
||||
)
|
||||
|
||||
return avg, median, std
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None) -> None:
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Main training pipeline with support for resuming training
|
||||
init_logging()
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
# Initialize training environment
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
# Setup dataset and dataloaders
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
root=cfg.dataset_root,
|
||||
local_files_only=cfg.local_files_only,
|
||||
)
|
||||
logging.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
n_total = len(dataset)
|
||||
n_train = int(cfg.train_split_proportion * len(dataset))
|
||||
train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
|
||||
val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
|
||||
|
||||
sampler = create_balanced_sampler(train_dataset, cfg)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=cfg.eval.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=cfg.training.num_workers,
|
||||
pin_memory=device.type == "cuda",
|
||||
)
|
||||
|
||||
# Resume training if requested
|
||||
step = 0
|
||||
best_val_acc = 0
|
||||
|
||||
if cfg.resume:
|
||||
if not Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
# Load and validate checkpoint configuration
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
"At least one difference was detected between the checkpoint configuration and "
|
||||
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
|
||||
"takes precedence.",
|
||||
)
|
||||
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
|
||||
cfg = checkpoint_cfg
|
||||
cfg.resume = True
|
||||
|
||||
# Initialize model and training components
|
||||
model = get_model(cfg=cfg, logger=logger).to(device)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
|
||||
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
|
||||
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
|
||||
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
|
||||
|
||||
# Log model parameters
|
||||
num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in model.parameters())
|
||||
logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
|
||||
logging.info(f"Total parameters: {format_big_number(num_total_params)}")
|
||||
|
||||
if cfg.resume:
|
||||
step = logger.load_last_training_state(optimizer, None)
|
||||
|
||||
# Training loop with validation and checkpointing
|
||||
for epoch in range(cfg.training.num_epochs):
|
||||
logging.info(f"\nEpoch {epoch + 1}/{cfg.training.num_epochs}")
|
||||
|
||||
train_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
criterion,
|
||||
optimizer,
|
||||
grad_scaler,
|
||||
device,
|
||||
logger,
|
||||
step,
|
||||
cfg,
|
||||
)
|
||||
|
||||
# Periodic validation
|
||||
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
|
||||
val_acc, eval_info = validate(
|
||||
model,
|
||||
val_loader,
|
||||
criterion,
|
||||
device,
|
||||
logger,
|
||||
cfg,
|
||||
)
|
||||
logger.log_dict(eval_info, step + len(train_loader), mode="eval")
|
||||
|
||||
# Save best model
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier="best",
|
||||
)
|
||||
|
||||
# Periodic checkpointing
|
||||
if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier=f"{epoch + 1:06d}",
|
||||
)
|
||||
|
||||
step += len(train_loader)
|
||||
|
||||
benchmark_inference_time(model, dataset, logger, cfg, device, step)
|
||||
|
||||
logging.info("Training completed")
|
||||
|
||||
|
||||
@hydra.main(
|
||||
version_base="1.2",
|
||||
config_name="hilserl_classifier",
|
||||
config_path="../configs/policy",
|
||||
)
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(
|
||||
out_dir=None,
|
||||
job_name=None,
|
||||
config_name="hilserl_classifier",
|
||||
config_path="../configs/policy",
|
||||
):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=config_path)
|
||||
cfg = compose(config_name=config_name)
|
||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
@@ -1,594 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import functools
|
||||
import logging
|
||||
import random
|
||||
from pprint import pformat
|
||||
from typing import Callable, Optional, Sequence, TypedDict
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.envs.factory import make_maniskill_env
|
||||
from lerobot.common.envs.utils import preprocess_maniskill_observation
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy):
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
)
|
||||
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: float
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: bool
|
||||
complementary_info: dict[str, torch.Tensor] = None
|
||||
|
||||
|
||||
class BatchTransition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
|
||||
|
||||
|
||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||
"""
|
||||
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||
(Same as shown previously.)
|
||||
"""
|
||||
B, C, H, W = images.shape
|
||||
crop_h, crop_w = output_size
|
||||
|
||||
if crop_h > H or crop_w > W:
|
||||
raise ValueError(
|
||||
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
|
||||
)
|
||||
|
||||
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
|
||||
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
|
||||
|
||||
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
|
||||
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
|
||||
|
||||
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
|
||||
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
|
||||
|
||||
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||
|
||||
# Gather pixels
|
||||
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||
|
||||
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||
return cropped
|
||||
|
||||
|
||||
def random_shift(images: torch.Tensor, pad: int = 4):
|
||||
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
|
||||
_, _, h, w = images.shape
|
||||
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
|
||||
return random_crop_vectorized(images=images, output_size=(h, w))
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.memory: list[Transition] = []
|
||||
self.position = 0
|
||||
|
||||
# If no state_keys provided, default to an empty list
|
||||
# (you can handle this differently if needed)
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
if image_augmentation_function is None:
|
||||
self.image_augmentation_function = functools.partial(random_shift, pad=4)
|
||||
self.use_drq = use_drq
|
||||
|
||||
def add(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
reward: float,
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Saves a transition."""
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(None)
|
||||
|
||||
# Create and store the Transition
|
||||
self.memory[self.position] = Transition(
|
||||
state=state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
self.position: int = (self.position + 1) % self.capacity
|
||||
|
||||
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
lerobot_dataset: LeRobotDataset,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
|
||||
Args:
|
||||
lerobot_dataset (LeRobotDataset): The dataset to convert.
|
||||
device (str): The device . Defaults to "cuda:0".
|
||||
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: The replay buffer with offline dataset transitions.
|
||||
"""
|
||||
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
|
||||
# a replay buffer than from a lerobot dataset.
|
||||
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
# Fill the replay buffer with the lerobot dataset transitions
|
||||
for data in list_transition:
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=data["action"],
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
)
|
||||
return replay_buffer
|
||||
|
||||
@staticmethod
|
||||
def _lerobotdataset_to_transitions(
|
||||
dataset: LeRobotDataset,
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
) -> list[Transition]:
|
||||
"""
|
||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
||||
|
||||
Args:
|
||||
dataset (LeRobotDataset):
|
||||
The dataset to convert. Each item in the dataset is expected to have
|
||||
at least the following keys:
|
||||
{
|
||||
"action": ...
|
||||
"next.reward": ...
|
||||
"next.done": ...
|
||||
"episode_index": ...
|
||||
}
|
||||
plus whatever your 'state_keys' specify.
|
||||
|
||||
state_keys (Optional[Sequence[str]]):
|
||||
The dataset keys to include in 'state' and 'next_state'. Their names
|
||||
will be kept as-is in the output transitions. E.g.
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
|
||||
# If not provided, you can either raise an error or define a default:
|
||||
if state_keys is None:
|
||||
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
|
||||
|
||||
transitions: list[Transition] = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
current_state: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = current_sample[key]
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
|
||||
next_state = current_state # default
|
||||
if not done and (i < num_frames - 1):
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] == current_sample["episode_index"]:
|
||||
# Build next_state from the same keys
|
||||
next_state_data: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = next_sample[key]
|
||||
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
|
||||
next_state = next_state_data
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
list_of_transitions = random.sample(self.memory, batch_size)
|
||||
|
||||
# -- Build batched states --
|
||||
batch_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_state[key] = self.image_augmentation_function(batch_state[key])
|
||||
|
||||
# -- Build batched actions --
|
||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
|
||||
|
||||
# -- Build batched rewards --
|
||||
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# -- Build batched next states --
|
||||
batch_next_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
|
||||
|
||||
# -- Build batched dones --
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# Return a BatchTransition typed dict
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
)
|
||||
|
||||
|
||||
def concatenate_batch_transitions(
|
||||
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
||||
) -> BatchTransition:
|
||||
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||
left_batch_transitions["state"] = {
|
||||
key: torch.cat(
|
||||
[
|
||||
left_batch_transitions["state"][key],
|
||||
right_batch_transition["state"][key],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["state"]
|
||||
}
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
)
|
||||
left_batch_transitions["next_state"] = {
|
||||
key: torch.cat(
|
||||
[
|
||||
left_batch_transitions["next_state"][key],
|
||||
right_batch_transition["next_state"][key],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
)
|
||||
return left_batch_transitions
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
# Create an env dedicated to online episodes collection from policy rollout.
|
||||
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
|
||||
# NOTE: Off policy algorithm are efficient enought to use a single environment
|
||||
logging.info("make_env online")
|
||||
# online_env = make_env(cfg, n_envs=1)
|
||||
# TODO: Remove the import of maniskill and unifiy with make env
|
||||
online_env = make_maniskill_env(cfg, n_envs=1)
|
||||
if cfg.training.eval_freq > 0:
|
||||
logging.info("make_env eval")
|
||||
# eval_env = make_env(cfg, n_envs=1)
|
||||
# TODO: Remove the import of maniskill and unifiy with make env
|
||||
eval_env = make_maniskill_env(cfg, n_envs=1)
|
||||
|
||||
# TODO: Add a way to resume training
|
||||
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
# TODO: At some point we should just need make sac policy
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
device=device,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
|
||||
# TODO: Handle resume
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# HACK for maniskill
|
||||
# obs = preprocess_observation(obs)
|
||||
obs = preprocess_maniskill_observation(obs)
|
||||
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
|
||||
|
||||
replay_buffer = ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
)
|
||||
|
||||
batch_size = cfg.training.batch_size
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
# NOTE: At some point we should use a wrapper to handle the observation
|
||||
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
action = policy.select_action(batch=obs)
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
# HACK
|
||||
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
|
||||
|
||||
# HACK: For maniskill
|
||||
# next_obs = preprocess_observation(next_obs)
|
||||
next_obs = preprocess_maniskill_observation(next_obs)
|
||||
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
|
||||
sum_reward_episode += float(reward[0])
|
||||
# Because we are using a single environment
|
||||
# we can safely assume that the episode is done
|
||||
if done[0] or truncated[0]:
|
||||
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
|
||||
sum_reward_episode = 0
|
||||
# HACK: This is for maniskill
|
||||
logging.info(
|
||||
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
|
||||
)
|
||||
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
|
||||
|
||||
replay_buffer.add(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=float(reward[0]),
|
||||
next_state=next_obs,
|
||||
done=done[0],
|
||||
)
|
||||
obs = next_obs
|
||||
|
||||
if interaction_step < cfg.training.online_step_before_learning:
|
||||
continue
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
if cfg.dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
if cfg.dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
|
||||
if interaction_step % cfg.training.policy_update_freq == 0:
|
||||
# TD3 Trick
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
loss_actor = policy.compute_loss_actor(observations=observations)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
|
||||
if interaction_step % cfg.training.log_freq == 0:
|
||||
logger.log_dict(training_infos, interaction_step, mode="train")
|
||||
|
||||
policy.update_target_networks()
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=config_path)
|
||||
cfg = compose(config_name=config_name)
|
||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
@@ -168,7 +168,7 @@ def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Verify that directories were created for each optimizer
|
||||
for name in multi_optimizers.keys():
|
||||
for name in multi_optimizers:
|
||||
assert (tmp_path / name).is_dir()
|
||||
assert (tmp_path / name / OPTIMIZER_STATE).is_file()
|
||||
assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
|
||||
@@ -204,7 +204,7 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify state dictionaries match
|
||||
for name in multi_optimizers.keys():
|
||||
for name in multi_optimizers:
|
||||
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user