cleaned eval_on_robot.py; readded policy; fixed doc strings
This commit is contained in:
@@ -15,25 +15,34 @@
|
||||
# limitations under the License.
|
||||
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
|
||||
|
||||
Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.
|
||||
This script supports performing human interventions during rollouts.
|
||||
Human interventions allow the user to take control of the robot from the policy
|
||||
and correct its behavior. It is specifically designed for reinforcement learning
|
||||
experiments and HIL-SERL (human-in-the-loop reinforcement learning) methods.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
-p outputs/train/model/checkpoints/005000/pretrained_model \
|
||||
eval.n_episodes=10
|
||||
```
|
||||
### How to Use
|
||||
|
||||
Test reward classifier with teleoperation (you need to press space to take over)
|
||||
To rollout a policy on the robot:
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--pretrained-policy-path-or-name path/to/pretrained_model \
|
||||
--policy-config path/to/policy/config.yaml \
|
||||
--display-cameras 1
|
||||
```
|
||||
|
||||
If you trained a reward classifier on your task, you can also evaluate it using this script.
|
||||
You can annotate the collection with a pre-trained reward classifier by running:
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--pretrained-policy-path-or-name path/to/pretrained_model \
|
||||
--policy-config path/to/policy/config.yaml \
|
||||
--reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
|
||||
--reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
|
||||
--display-cameras 1
|
||||
```
|
||||
|
||||
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
||||
for running training on the real robot.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -46,7 +55,8 @@ import torch
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position, predict_action
|
||||
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
@@ -69,7 +79,6 @@ def get_classifier(pretrained_path, config_path):
|
||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
||||
model = Classifier(classifier_config)
|
||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
||||
model = model.to("mps")
|
||||
return model
|
||||
|
||||
|
||||
@@ -81,48 +90,45 @@ def rollout(
|
||||
control_time_s: float = 20,
|
||||
use_amp: bool = True,
|
||||
display_cameras: bool = False,
|
||||
device: str = "cpu"
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout on the real robot.
|
||||
|
||||
The return dictionary contains:
|
||||
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
keys. NOTE the that this has an extra sequence element relative to the other keys in the
|
||||
dictionary. This is because an extra observation is included for after the environment is
|
||||
terminated or truncated.
|
||||
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
|
||||
including the last observations).
|
||||
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
|
||||
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
|
||||
environment termination/truncation).
|
||||
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
the first True is followed by True's all the way till the end. This can be used for masking
|
||||
extraneous elements from the sequences above.
|
||||
This function executes a rollout using the provided policy and robot interface,
|
||||
simulating batched interactions for a fixed control duration.
|
||||
|
||||
The returned dictionary contains rollout statistics, which can be used for analysis and debugging.
|
||||
|
||||
Args:
|
||||
robot: The robot class that defines the interface with the real robot.
|
||||
policy: The policy. Must be a PyTorch nn module.
|
||||
"robot": The robot interface for interacting with the real robot hardware.
|
||||
"policy": The policy to execute. Must be a PyTorch `nn.Module` object.
|
||||
"reward_classifier": A module to classify rewards during the rollout.
|
||||
"fps": The control frequency at which the policy is executed.
|
||||
"control_time_s": The total control duration of the rollout in seconds.
|
||||
"use_amp": Whether to use automatic mixed precision (AMP) for policy evaluation.
|
||||
"display_cameras": If True, displays camera streams during the rollout.
|
||||
"device": The device to use for computations (e.g., "cpu", "cuda" or "mps").
|
||||
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
Dictionary of the statisitcs collected during rollouts.
|
||||
"""
|
||||
# TODO (michel-aractingi): Infer the device from policy parameters when policy is added
|
||||
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
|
||||
# device = get_device_from_parameters(policy)
|
||||
|
||||
# define keyboard listener
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||
# policy.reset()
|
||||
if policy is not None:
|
||||
policy.reset()
|
||||
|
||||
# NOTE: sorting to make sure the key sequence is the same during training and testing.
|
||||
observation = robot.capture_observation()
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
image_keys.sort()
|
||||
image_keys.sort() # CG{T}
|
||||
|
||||
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
all_successes = []
|
||||
|
||||
indices_from_policy = []
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
init_pos = robot.follower_arms["main"].read("Present_Position")
|
||||
@@ -141,27 +147,32 @@ def rollout(
|
||||
else:
|
||||
# explore with policy
|
||||
with torch.inference_mode():
|
||||
# TODO (michel-aractingi) replace this part with policy (predict_action)
|
||||
action = robot.follower_arms["main"].read("Present_Position")
|
||||
action = torch.from_numpy(action)
|
||||
# TODO (michel-aractingi) in placy temporarly for testing purposes
|
||||
if policy is None:
|
||||
action = robot.follower_arms["main"].read("Present_Position")
|
||||
action = torch.from_numpy(action)
|
||||
indices_from_policy.append(False)
|
||||
else:
|
||||
action = predict_action(observation, policy, device, use_amp)
|
||||
indices_from_policy.append(True)
|
||||
|
||||
robot.send_action(action)
|
||||
# action = predict_action(observation, policy, device, use_amp)
|
||||
observation = robot.capture_observation()
|
||||
|
||||
|
||||
observation = robot.capture_observation()
|
||||
images = []
|
||||
for key in image_keys:
|
||||
if display_cameras:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
images.append(observation[key].to("mps"))
|
||||
images.append(observation[key].to(device))
|
||||
|
||||
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
|
||||
|
||||
# TODO send data through the server as soon as you have it
|
||||
|
||||
all_rewards.append(reward)
|
||||
|
||||
# print("REWARD : ", reward)
|
||||
|
||||
all_actions.append(action)
|
||||
all_successes.append(torch.tensor([False]))
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
@@ -180,7 +191,6 @@ def rollout(
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
"next.reward": torch.stack(all_rewards, dim=1),
|
||||
"next.success": torch.stack(all_successes, dim=1),
|
||||
"done": dones,
|
||||
}
|
||||
|
||||
@@ -199,14 +209,32 @@ def eval_policy(
|
||||
display_cameras: bool = False,
|
||||
reward_classifier_pretrained_path: str | None = None,
|
||||
reward_classifier_config_file: str | None = None,
|
||||
device: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Evaluate a policy on a real robot by running multiple episodes and collecting metrics.
|
||||
|
||||
This function executes rollouts of the specified policy on the robot, computes metrics
|
||||
for the rollouts, and optionally evaluates a reward classifier if provided.
|
||||
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy.
|
||||
n_episodes: The number of episodes to evaluate.
|
||||
"robot": The robot interface used to interact with the real robot hardware.
|
||||
"policy": The policy to be evaluated. Must be a PyTorch neural network module.
|
||||
"fps": Frames per second (control frequency) for running the policy.
|
||||
"n_episodes": The number of episodes to evaluate the policy.
|
||||
"control_time_s": The max duration for each episode in seconds.
|
||||
"use_amp": Whether to use automatic mixed precision (AMP) for policy evaluation.
|
||||
"display_cameras": Whether to display camera streams during rollouts.
|
||||
"reward_classifier_pretrained_path": Path to the pretrained reward classifier.
|
||||
If provided, the reward classifier will be evaluated during rollouts.
|
||||
"reward_classifier_config_file": Path to the configuration file for the reward classifier.
|
||||
Required if `reward_classifier_pretrained_path` is provided.
|
||||
"device": The device for computations (e.g., "cpu", "cuda" or "mps").
|
||||
|
||||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"dict": A dictionary containing the following rollout metrics and data:
|
||||
- "metrics": Evaluation metrics such as cumulative rewards, success rates, etc.
|
||||
- "rollout_data": Detailed data from the rollouts, including observations, actions, rewards, and done flags.
|
||||
"""
|
||||
# TODO (michel-aractingi) comment this out for testing with a fixed policy
|
||||
# assert isinstance(policy, Policy)
|
||||
@@ -214,22 +242,22 @@ def eval_policy(
|
||||
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
rollouts = []
|
||||
|
||||
start_eval = time.perf_counter()
|
||||
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
|
||||
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
|
||||
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file).to(device)§
|
||||
|
||||
device = get_device_from_parameters(policy) if device is None else device
|
||||
|
||||
for _ in progbar:
|
||||
rollout_data = rollout(
|
||||
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
|
||||
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras, device
|
||||
)
|
||||
|
||||
rollouts.append(rollout_data)
|
||||
sum_rewards.append(sum(rollout_data["next.reward"]))
|
||||
max_rewards.append(max(rollout_data["next.reward"]))
|
||||
successes.append(rollout_data["next.success"][-1])
|
||||
|
||||
info = {
|
||||
"per_episode": [
|
||||
@@ -237,21 +265,18 @@ def eval_policy(
|
||||
"episode_ix": i,
|
||||
"sum_reward": sum_reward,
|
||||
"max_reward": max_reward,
|
||||
"pc_success": success * 100,
|
||||
}
|
||||
for i, (sum_reward, max_reward, success) in enumerate(
|
||||
for i, (sum_reward, max_reward) in enumerate(
|
||||
zip(
|
||||
sum_rewards[:n_episodes],
|
||||
max_rewards[:n_episodes],
|
||||
successes[:n_episodes],
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
],
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
|
||||
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
|
||||
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
|
||||
"pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
|
||||
"eval_s": time.time() - start_eval,
|
||||
"eval_ep_s": (time.time() - start_eval) / n_episodes,
|
||||
},
|
||||
@@ -264,9 +289,18 @@ def eval_policy(
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
"""
|
||||
Initialize a keyboard listener for controlling the recording and human intervention process.
|
||||
|
||||
Keyboard controls: (Note that this might require sudo permissions to monitor keyboard events)
|
||||
- Right Arrow Key ('->'): Stops the current recording and exits early, useful for ending an episode
|
||||
and moving the next episode recording.
|
||||
- Left Arrow Key ('<-'): Re-records the current episode, allowing the user to start over.
|
||||
- Space Bar: Controls the human intervention process in three steps:
|
||||
1. First press pauses the policy and prompts the user to position the leader similar to the follower.
|
||||
2. Second press initiates human interventions, allowing teleop control of the robot.
|
||||
3. Third press resumes the policy rollout.
|
||||
"""
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
@@ -302,10 +336,15 @@ def init_keyboard_listener():
|
||||
)
|
||||
events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
else:
|
||||
elif events["pause_policy"] and not events["human_intervention_step"]:
|
||||
events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
log_say("Starting human intervention.", play_sounds=True)
|
||||
elif events["human_intervention_step"]:
|
||||
events["human_intervention_step"] = False
|
||||
events["pause_policy"] = False
|
||||
print("Space key pressed. Human intervention ending, policy resumes control.")
|
||||
log_say("Policy resuming.", play_sounds=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
Reference in New Issue
Block a user