# lerobot_piper/lerobot/common/robot_devices/control_utils.py

########################################################################################
# Utilities
########################################################################################

import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
from functools import cache

import cv2
import numpy as np
import torch
from deepdiff import DeepDiff
from termcolor import colored

from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method


def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
    log_items = []
    if episode_index is not None:
        log_items.append(f"ep:{episode_index}")
    if frame_index is not None:
        log_items.append(f"frame:{frame_index}")

    def log_dt(shortname, dt_val_s):
        nonlocal log_items, fps
        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
        if fps is not None:
            actual_fps = 1 / dt_val_s
            if actual_fps < fps - 1:
                info_str = colored(info_str, "yellow")
        log_items.append(info_str)

    # total step time displayed in milliseconds and its frequency
    log_dt("dt", dt_s)

    # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
    if not robot.robot_type.startswith(("stretch", "piper")):
        for name in robot.leader_arms:
            key = f"read_leader_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRlead", robot.logs[key])

        for name in robot.follower_arms:
            key = f"write_follower_{name}_goal_pos_dt_s"
            if key in robot.logs:
                log_dt("dtWfoll", robot.logs[key])

            key = f"read_follower_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRfoll", robot.logs[key])

        for name in robot.cameras:
            key = f"read_camera_{name}_dt_s"
            if key in robot.logs:
                log_dt(f"dtR{name}", robot.logs[key])

    info_str = " ".join(log_items)
    logging.info(info_str)
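

# Example log line (illustrative values, 30 fps loop with one leader-arm read):
#     ep:3 frame:120 dt:33.25 (30.1hz) dtRlead: 1.02 (980.4hz)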


@cache
def is_headless():
    """Detects if python is running without a monitor."""
    try:
        import pynput  # noqa

        return False
    except Exception:
        print(
            "Error trying to import pynput. Switching to headless mode. "
            "As a result, the video stream from the cameras won't be shown, "
            "and you won't be able to change the control flow with keyboards. "
            "For more info, see traceback below.\n"
        )
        traceback.print_exc()
        print()
        return True


def predict_action(observation, policy, device, use_amp):
    observation = copy(observation)
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

        # Compute the next action with the policy
        # based on the current observation
        action = policy.select_action(observation)

        # Remove batch dimension
        action = action.squeeze(0)

        # Move to cpu, if not already the case
        action = action.to("cpu")

    return action
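

# Usage sketch (illustrative): `predict_action` expects the raw observation
# dict from the robot, with HxWx3 uint8 image tensors and 1D state tensors:
#
#     observation = robot.capture_observation()
#     device = get_safe_torch_device("cuda")
#     action = predict_action(observation, policy, device, use_amp=True)
#     robot.send_action(action)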


def init_keyboard_listener():
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {}
    events["exit_early"] = False
    events["rerecord_episode"] = False
    events["stop_recording"] = False
    events["confirm_save"] = False
    events["discard_episode"] = False

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        listener = None
        return listener, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting current recording...")
                events["exit_early"] = True
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Interrupting and preparing to rerecord...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording session...")
                events["stop_recording"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.enter:
                print("Enter key pressed. Confirming and saving current episode...")
                events["confirm_save"] = True
            elif key == keyboard.Key.backspace:
                print("Backspace key pressed. Discarding completed episode...")
                events["discard_episode"] = True
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events
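

# Usage sketch (illustrative): the returned `events` dict is shared with the
# control loop, which polls its flags once per iteration:
#
#     listener, events = init_keyboard_listener()
#     warmup_record(robot, events, enable_teleoperation=True, warmup_time_s=5, display_cameras=True, fps=30)
#     if not is_headless() and listener is not None:
#         listener.stop()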


def warmup_record(
    robot,
    events,
    enable_teleoperation,
    warmup_time_s,
    display_cameras,
    fps,
):
    control_loop(
        robot=robot,
        control_time_s=warmup_time_s,
        display_cameras=display_cameras,
        events=events,
        fps=fps,
        teleoperate=enable_teleoperation,
    )


def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    device,
    use_amp,
    fps,
    single_task,
):
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        device=device,
        use_amp=use_amp,
        fps=fps,
        teleoperate=policy is None,
        single_task=single_task,
    )


@safe_stop_image_writer
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    device: torch.device | str | None = None,
    use_amp: bool | None = None,
    fps: int | None = None,
    single_task: str | None = None,
):
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if teleoperate and policy is not None:
        raise ValueError("When `teleoperate` is True, `policy` should be None.")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    if isinstance(device, str):
        device = get_safe_torch_device(device)

    timestamp = 0
    start_episode_t = time.perf_counter()
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step(record_data=True)
        else:
            observation = robot.capture_observation()

            if policy is not None:
                pred_action = predict_action(observation, policy, device, use_amp)
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            # image_keys = [key for key in observation if "image" in key]
            # for key in image_keys:
            #     cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            # cv2.waitKey(1)
            display_observations_combined(observation)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break
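

# Usage sketch (illustrative): a plain teleoperation session with the combined
# camera window, assuming a connected and calibrated `robot`:
#
#     listener, events = init_keyboard_listener()
#     control_loop(robot, control_time_s=60, teleoperate=True, display_cameras=True, fps=30, events=events)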


def display_observations_combined(observation, display_cameras=True):
    """Combine all camera feeds into a single window, laid out to avoid window jitter."""
    if display_cameras and not is_headless():
        image_keys = [key for key in observation if "image" in key]
        if not image_keys:
            return

        # Collect all images and convert them to BGR format
        images = []
        for key in image_keys:
            img = observation[key].numpy()
            # Make sure every image is 640x480
            if img.shape[0] != 480 or img.shape[1] != 640:
                img = cv2.resize(img, (640, 480))
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            # Overlay the camera key as a title
            cv2.putText(img, key, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            images.append(img)

        # Pick a fixed layout based on the actual number of cameras
        num_cameras = len(images)
        margin = 10  # spacing between images

        if num_cameras == 1:
            # Single camera - show it directly
            grid_image = images[0]
            grid_height, grid_width = grid_image.shape[:2]
        elif num_cameras == 2:
            # Two cameras - side by side
            grid_width = 2 * 640 + margin
            grid_height = 480
            grid_image = np.zeros((grid_height, grid_width, 3), dtype=np.uint8)
            grid_image[0:480, 0:640] = images[0]
            grid_image[0:480, 640 + margin :] = images[1]
        elif num_cameras == 3:
            # Three cameras - two on the top row, one centered below
            grid_width = 2 * 640 + margin
            grid_height = 2 * 480 + margin
            grid_image = np.zeros((grid_height, grid_width, 3), dtype=np.uint8)
            grid_image[0:480, 0:640] = images[0]
            grid_image[0:480, 640 + margin :] = images[1]
            grid_image[480 + margin :, (grid_width - 640) // 2 : (grid_width + 640) // 2] = images[2]
        else:
            # Four or more cameras - 2x2 grid (only the first four are shown)
            grid_cols = 2
            grid_rows = 2
            grid_width = grid_cols * 640 + (grid_cols - 1) * margin
            grid_height = grid_rows * 480 + (grid_rows - 1) * margin
            grid_image = np.zeros((grid_height, grid_width, 3), dtype=np.uint8)
            for i, img in enumerate(images[:4]):
                row = i // grid_cols
                col = i % grid_cols
                y_start = row * (480 + margin)
                y_end = y_start + 480
                x_start = col * (640 + margin)
                x_end = x_start + 640
                grid_image[y_start:y_end, x_start:x_end] = img

        # Reuse a single named window
        window_name = "Camera Views"
        try:
            # getWindowProperty raises in some builds if the window does not exist yet
            window_exists = cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) >= 0
        except cv2.error:
            window_exists = False

        # If the window does not exist yet, create, size and position it once
        if not window_exists:
            cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
            screen_width = 1920  # assumed screen resolution
            screen_height = 1080
            # Scale the grid to fit within the screen
            scale_factor = min(screen_width / grid_width, screen_height / grid_height) * 0.95
            display_width = int(grid_width * scale_factor)
            display_height = int(grid_height * scale_factor)
            cv2.resizeWindow(window_name, display_width, display_height)
            # Center the window on screen
            cv2.moveWindow(
                window_name,
                (screen_width - display_width) // 2,
                (screen_height - display_height) // 2,
            )

        cv2.imshow(window_name, grid_image)
        cv2.waitKey(1)
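

# Note: because the "Camera Views" window is created, sized and centered only
# on its first call, later calls just refresh the image in place, which is
# what keeps the display from jittering between frames.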


def reset_environment(robot, events, reset_time_s, fps):
    # TODO(rcadene): refactor warmup_record and reset_environment
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    control_loop(
        robot=robot,
        control_time_s=reset_time_s,
        events=events,
        fps=fps,
        teleoperate=True,
    )


def stop_recording(robot, listener, display_cameras):
    robot.disconnect()

    if not is_headless():
        if listener is not None:
            listener.stop()

        if display_cameras:
            cv2.destroyAllWindows()


def sanity_check_dataset_name(repo_id, policy_cfg):
    dataset_name = repo_id.split("/")[-1]
    # Either repo_id doesn't start with "eval_" and there is no policy,
    # or repo_id starts with "eval_" and there is a policy.

    # Check if dataset_name starts with "eval_" but policy is missing
    if dataset_name.startswith("eval_") and policy_cfg is None:
        raise ValueError(
            f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
        )

    # Check if dataset_name does not start with "eval_" but policy is provided
    if not dataset_name.startswith("eval_") and policy_cfg is not None:
        raise ValueError(
            f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
        )


def sanity_check_dataset_robot_compatibility(
    dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
    fields = [
        ("robot_type", dataset.meta.robot_type, robot.robot_type),
        ("fps", dataset.fps, fps),
        ("features", dataset.features, get_features_from_robot(robot, use_videos)),
    ]

    mismatches = []
    for field, dataset_value, present_value in fields:
        diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
        if diff:
            mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")

    if mismatches:
        raise ValueError(
            "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
        )