[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
parent d8a1758122
commit 584cad808e
108 changed files with 3894 additions and 1189 deletions

View File

@@ -46,7 +46,11 @@ import torch
from tqdm import trange
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
init_hydra_config,
@@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path):
return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps")
@@ -151,11 +161,17 @@ def rollout(
images = []
for key in image_keys:
if display_cameras:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
images.append(observation[key].to("mps"))
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
reward = (
reward_classifier.predict_reward(images)
if reward_classifier is not None
else 0.0
)
all_rewards.append(reward)
# print("REWARD : ", reward)
@@ -219,11 +235,19 @@ def eval_policy(
start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
reward_classifier = get_classifier(
reward_classifier_pretrained_path, reward_classifier_config_file
)
for _ in progbar:
rollout_data = rollout(
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
robot,
policy,
reward_classifier,
fps,
control_time_s,
use_amp,
display_cameras,
)
rollouts.append(rollout_data)
@@ -289,7 +313,9 @@ def init_keyboard_listener():
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.space:
@@ -301,7 +327,10 @@ def init_keyboard_listener():
"Place the leader in similar pose to the follower and press space again."
)
events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
log_say(
"Human intervention stage. Get ready to take over.",
play_sounds=True,
)
else:
events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
@@ -351,7 +380,9 @@ if __name__ == "__main__":
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument(
"--out-dir",
help=(
@@ -360,7 +391,8 @@ if __name__ == "__main__":
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
"--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
)
parser.add_argument(
"--reward-classifier-pretrained-path",