forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -46,7 +46,11 @@ import torch
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
busy_wait,
|
||||
is_headless,
|
||||
reset_follower_position,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
@@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path):
|
||||
return
|
||||
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
|
||||
cfg = init_hydra_config(config_path)
|
||||
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
||||
classifier_config.num_cameras = len(
|
||||
cfg.training.image_keys
|
||||
) # TODO automate these paths
|
||||
model = Classifier(classifier_config)
|
||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
||||
model = model.to("mps")
|
||||
@@ -151,11 +161,17 @@ def rollout(
|
||||
images = []
|
||||
for key in image_keys:
|
||||
if display_cameras:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.imshow(
|
||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||
)
|
||||
cv2.waitKey(1)
|
||||
images.append(observation[key].to("mps"))
|
||||
|
||||
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
|
||||
reward = (
|
||||
reward_classifier.predict_reward(images)
|
||||
if reward_classifier is not None
|
||||
else 0.0
|
||||
)
|
||||
all_rewards.append(reward)
|
||||
|
||||
# print("REWARD : ", reward)
|
||||
@@ -219,11 +235,19 @@ def eval_policy(
|
||||
|
||||
start_eval = time.perf_counter()
|
||||
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
|
||||
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
|
||||
reward_classifier = get_classifier(
|
||||
reward_classifier_pretrained_path, reward_classifier_config_file
|
||||
)
|
||||
|
||||
for _ in progbar:
|
||||
rollout_data = rollout(
|
||||
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
|
||||
robot,
|
||||
policy,
|
||||
reward_classifier,
|
||||
fps,
|
||||
control_time_s,
|
||||
use_amp,
|
||||
display_cameras,
|
||||
)
|
||||
|
||||
rollouts.append(rollout_data)
|
||||
@@ -289,7 +313,9 @@ def init_keyboard_listener():
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
print(
|
||||
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
|
||||
)
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.space:
|
||||
@@ -301,7 +327,10 @@ def init_keyboard_listener():
|
||||
"Place the leader in similar pose to the follower and press space again."
|
||||
)
|
||||
events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
log_say(
|
||||
"Human intervention stage. Get ready to take over.",
|
||||
play_sounds=True,
|
||||
)
|
||||
else:
|
||||
events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
@@ -351,7 +380,9 @@ if __name__ == "__main__":
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
@@ -360,7 +391,8 @@ if __name__ == "__main__":
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
|
||||
"--display-cameras",
|
||||
help=("Whether to display the camera feed while the rollout is happening"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-pretrained-path",
|
||||
|
||||
Reference in New Issue
Block a user