Extend reward classifier for multiple camera views (#626)
@@ -109,6 +109,7 @@ from lerobot.common.robot_devices.control_utils import (
    log_control_info,
    record_episode,
    reset_environment,
    reset_follower_position,
    sanity_check_dataset_name,
    sanity_check_dataset_robot_compatibility,
    stop_recording,
@@ -205,6 +206,7 @@ def record(
    num_image_writer_threads_per_camera: int = 4,
    display_cameras: bool = True,
    play_sounds: bool = True,
    reset_follower: bool = False,
    resume: bool = False,
    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
    local_files_only: bool = False,
@@ -265,6 +267,9 @@ def record(
    robot.connect()
    listener, events = init_keyboard_listener(assign_rewards=assign_rewards)

    if reset_follower:
        initial_position = robot.follower_arms["main"].read("Present_Position")

    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it into its starting position if no policy is provided,
    # 2. give time to the robot devices to connect and start synchronizing,
@@ -307,6 +312,8 @@ def record(
            (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", play_sounds)
            if reset_follower:
                reset_follower_position(robot, initial_position)
            reset_environment(robot, events, reset_time_s)

        if events["rerecord_episode"]:
@@ -527,6 +534,12 @@ if __name__ == "__main__":
        default=0,
        help="Enables the assignment of rewards to frames (disabled by default). When enabled, frames are assigned a reward of 0 until the space bar is pressed, which switches the reward to 1. Press the space bar a second time to switch the reward back to 0. The assigned reward is reset to 0 when the episode ends.",
    )
    parser_record.add_argument(
        "--reset-follower",
        type=int,
        default=0,
        help="Resets the follower to its initial position while the environment is being reset; this avoids having the follower start the next episode in an awkward position.",
    )
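The space-bar behaviour described in the `--assign-rewards` help above is wired through `init_keyboard_listener(assign_rewards=assign_rewards)` in the hunk earlier in this file. The listener itself is not shown in this diff; below is only a minimal sketch of how such a toggle can be built with `pynput`, and the `"next.reward"` event key is a placeholder, not the commit's actual key name.

```
# Sketch only: a space-bar toggle for reward labeling, in the spirit of the
# --assign-rewards help text. The real init_keyboard_listener is not shown in
# this diff; the "next.reward" event key below is a placeholder.
from pynput import keyboard


def make_reward_toggle_listener(events: dict) -> keyboard.Listener:
    events.setdefault("next.reward", 0)

    def on_press(key):
        # Each press of the space bar flips the reward between 0 and 1.
        if key == keyboard.Key.space:
            events["next.reward"] = 1 - events["next.reward"]

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener
```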

    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(

@@ -23,6 +23,15 @@ python lerobot/scripts/eval_on_robot.py \
    eval.n_episodes=10
```

Test the reward classifier with teleoperation (you need to press the space bar to take over):
```
python lerobot/scripts/eval_on_robot.py \
    --robot-path lerobot/configs/robot/so100.yaml \
    --reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
    --reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
    --display-cameras 1
```

**NOTE** (michel-aractingi): This script is incomplete and is being prepared
for running training on the real robot.
"""
@@ -30,14 +39,14 @@ for running training on the real robot.
import argparse
import logging
import time
from copy import deepcopy

import cv2
import numpy as np
import torch
from tqdm import trange

from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
    init_hydra_config,
@@ -46,7 +55,33 @@ from lerobot.common.utils.utils import (
)


def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
def get_classifier(pretrained_path, config_path):
    if pretrained_path is None or config_path is None:
        return

    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    cfg = init_hydra_config(config_path)

    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
    model = Classifier(classifier_config)
    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
    model = model.to("mps")
    return model
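As a rough usage sketch, the returned model is later fed one tensor per camera view, in the same sorted key order used at training time; this mirrors how `rollout` below builds the image list before calling `predict_reward`. The checkpoint paths come from the example command in the docstring, a connected `robot` is assumed, and the hard-coded "mps" device may need to change on other machines.

```
# Sketch: how the classifier returned by get_classifier is used in rollout().
classifier = get_classifier(
    "outputs/classifier/checkpoints/best/pretrained_model",
    "lerobot/configs/policy/hilserl_classifier.yaml",
)

observation = robot.capture_observation()  # assumes a connected `robot`
image_keys = sorted(key for key in observation if "image" in key)
images = [observation[key].to("mps") for key in image_keys]
reward = classifier.predict_reward(images) if classifier is not None else 0.0
```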


def rollout(
    robot: Robot,
    policy: Policy,
    reward_classifier,
    fps: int,
    control_time_s: float = 20,
    use_amp: bool = True,
    display_cameras: bool = False,
) -> dict:
    """Run a batched policy rollout on the real robot.

    The return dictionary contains:
@@ -70,6 +105,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
    Returns:
        The dictionary described above.
    """
    # TODO (michel-aractingi): Infer the device from policy parameters when policy is added
    # assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
    # device = get_device_from_parameters(policy)

@@ -79,25 +115,21 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
    # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
    # policy.reset()

    # Get observation from real robot
    # NOTE: sorting to make sure the key sequence is the same during training and testing.
    observation = robot.capture_observation()
    image_keys = [key for key in observation if "image" in key]
    image_keys.sort()

    # Calculate reward. TODO (michel-aractingi)
    # in HIL-SERL it will be with a reward classifier
    reward = calculate_reward(observation)
    all_observations = []
    all_actions = []
    all_rewards = []
    all_successes = []

    start_episode_t = time.perf_counter()
    init_pos = robot.follower_arms["main"].read("Present_Position")
    timestamp = 0.0
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        all_observations.append(deepcopy(observation))
        # observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

        # Apply the next action.
        while events["pause_policy"] and not events["human_intervention_step"]:
            busy_wait(0.5)
@@ -109,18 +141,26 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
        else:
            # explore with policy
            with torch.inference_mode():
                # TODO (michel-aractingi) replace this part with policy (predict_action)
                action = robot.follower_arms["main"].read("Present_Position")
                action = torch.from_numpy(action)
                robot.send_action(action)
                # action = predict_action(observation, policy, device, use_amp)

        observation = robot.capture_observation()
        # Calculate reward
        # in HIL-SERL it will be with a reward classifier
        reward = calculate_reward(observation)
        images = []
        for key in image_keys:
            if display_cameras:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
                cv2.waitKey(1)
            images.append(observation[key].to("mps"))

        reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
        all_rewards.append(reward)

        # print("REWARD : ", reward)

        all_actions.append(action)
        all_rewards.append(torch.from_numpy(reward))
        all_successes.append(torch.tensor([False]))

        dt_s = time.perf_counter() - start_loop_t
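The loop above is the core of the multi-camera extension: one image tensor is collected per sorted camera key and the whole list is passed to `predict_reward`. The classifier internals are not part of this diff; the sketch below is only one plausible shape of a multi-view reward head, where the shared encoder, feature concatenation, and sigmoid threshold are all assumptions.

```
# Sketch only: one way a reward classifier can fuse several camera views.
# The real Classifier in lerobot.common.policies.hilserl.classifier may differ.
import torch
import torch.nn as nn


class MultiViewRewardHead(nn.Module):
    def __init__(self, encoder: nn.Module, feat_dim: int, num_cameras: int):
        super().__init__()
        self.encoder = encoder  # shared per-view image encoder -> (B, feat_dim)
        self.head = nn.Linear(feat_dim * num_cameras, 1)

    @torch.inference_mode()
    def predict_reward(self, images: list[torch.Tensor]) -> float:
        # images: one tensor per camera view, each (C, H, W) or (B, C, H, W)
        feats = [self.encoder(img if img.dim() == 4 else img.unsqueeze(0)) for img in images]
        logit = self.head(torch.cat(feats, dim=-1))
        # Binary success signal: 1.0 if the fused views look successful, else 0.0
        return float((torch.sigmoid(logit) > 0.5).float().mean().item())
```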
@@ -131,7 +171,8 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
            events["human_intervention_step"] = False
            events["pause_policy"] = False
            break
    all_observations.append(deepcopy(observation))

    reset_follower_position(robot, target_position=init_pos)

    dones = torch.tensor([False] * len(all_actions))
    dones[-1] = True
@@ -142,10 +183,6 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
        "next.success": torch.stack(all_successes, dim=1),
        "done": dones,
    }
    stacked_observations = {}
    for key in all_observations[0]:
        stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
    ret["observation"] = stacked_observations

    listener.stop()

@@ -159,6 +196,9 @@ def eval_policy(
    n_episodes: int,
    control_time_s: int = 20,
    use_amp: bool = True,
    display_cameras: bool = False,
    reward_classifier_pretrained_path: str | None = None,
    reward_classifier_config_file: str | None = None,
) -> dict:
    """
    Args:
@@ -179,8 +219,12 @@ def eval_policy(

    start_eval = time.perf_counter()
    progbar = trange(n_episodes, desc="Evaluating policy on real robot")
    for _batch_idx in progbar:
        rollout_data = rollout(robot, policy, fps, control_time_s, use_amp)
    reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)

    for _ in progbar:
        rollout_data = rollout(
            robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
        )

        rollouts.append(rollout_data)
        sum_rewards.append(sum(rollout_data["next.reward"]))
@@ -219,15 +263,6 @@ def eval_policy(
    return info


def calculate_reward(observation):
    """
    Method to calculate the reward in some way.
    In HIL-SERL this is done by defining a reward classifier.
    """
    # reward = reward_classifier(observation)
    return np.array([0.0])


def init_keyboard_listener():
    # Allow exiting early while recording an episode or resetting the environment
    # by tapping the right arrow key '->'. This might require a sudo permission
@@ -324,6 +359,21 @@ if __name__ == "__main__":
            "outputs/eval/{timestamp}_{env_name}_{policy_name}"
        ),
    )
    parser.add_argument(
        "--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
    )
    parser.add_argument(
        "--reward-classifier-pretrained-path",
        type=str,
        default=None,
        help="Path to the pretrained classifier weights.",
    )
    parser.add_argument(
        "--reward-classifier-config-file",
        type=str,
        default=None,
        help="Path to a yaml config file that is necessary to build the reward classifier model.",
    )

    args = parser.parse_args()

@@ -332,4 +382,13 @@ if __name__ == "__main__":
    if not robot.is_connected:
        robot.connect()

    eval_policy(robot, None, fps=40, n_episodes=2, control_time_s=100)
    eval_policy(
        robot,
        None,
        fps=40,
        n_episodes=2,
        control_time_s=100,
        display_cameras=args.display_cameras,
        reward_classifier_config_file=args.reward_classifier_config_file,
        reward_classifier_pretrained_path=args.reward_classifier_pretrained_path,
    )

@@ -22,6 +22,7 @@ from pprint import pformat
import hydra
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm

import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
@@ -79,7 +79,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        start_time = time.perf_counter()
        images = batch[cfg.training.image_key].to(device)
        images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
        labels = batch[cfg.training.label_key].float().to(device)

        # Forward pass with optional AMP
@@ -130,7 +130,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
        torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
    ):
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch[cfg.training.image_key].to(device)
            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
            labels = batch[cfg.training.label_key].float().to(device)

            outputs = model(images)
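With `cfg.training.image_keys`, the train and validation loops now hand the model a list of per-camera batches instead of a single tensor. The illustration below shows what such a batch could look like; the camera keys, shapes, and label key are placeholders standing in for the hydra config values, not the commit's actual configuration.

```
# Illustration only: a batch with two camera views. Keys, shapes, and the
# label key are placeholders for cfg.training.image_keys / label_key.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {
    "observation.images.laptop": torch.rand(32, 3, 240, 320),
    "observation.images.phone": torch.rand(32, 3, 240, 320),
    "next.reward": torch.randint(0, 2, (32,)),
}
image_keys = ["observation.images.laptop", "observation.images.phone"]
images = [batch[k].to(device) for k in image_keys]  # list: one tensor per camera
labels = batch["next.reward"].float().to(device)
# outputs = model(images)  # the classifier's forward consumes the whole list
```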
@@ -163,6 +163,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l

    accuracy = 100 * correct / total
    avg_loss = running_loss / len(val_loader)
    print(f"Average validation loss {avg_loss}, and accuracy {accuracy}")

    eval_info = {
        "loss": avg_loss,