445 lines
15 KiB
Python
445 lines
15 KiB
Python
########################################################################################
|
||
# Utilities
|
||
########################################################################################
|
||
|
||
|
||
import logging
|
||
import time
|
||
import traceback
|
||
from contextlib import nullcontext
|
||
from copy import copy
|
||
from functools import cache
|
||
|
||
import cv2
|
||
import torch
|
||
from deepdiff import DeepDiff
|
||
from termcolor import colored
|
||
|
||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||
from lerobot.common.datasets.utils import get_features_from_robot
|
||
from lerobot.common.robot_devices.robots.utils import Robot
|
||
from lerobot.common.robot_devices.utils import busy_wait
|
||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||
|
||
|
||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
    """Log a one-line summary of the timings of one control step.

    The line holds the optional episode and frame indices, the total step
    duration ``dt_s`` and every per-device timing found in ``robot.logs``.
    Each timing is rendered in milliseconds together with its equivalent
    frequency, and is colored yellow when ``fps`` is given and the measured
    frequency is more than 1 Hz below it.
    """
    parts = []
    if episode_index is not None:
        parts.append(f"ep:{episode_index}")
    if frame_index is not None:
        parts.append(f"frame:{frame_index}")

    def add_timing(label, seconds):
        # Render one duration and flag it when it is slower than requested.
        hz = 1 / seconds
        text = f"{label}:{seconds * 1000:5.2f} ({hz:3.1f}hz)"
        if fps is not None and hz < fps - 1:
            text = colored(text, "yellow")
        parts.append(text)

    def add_logged_timing(label, log_key):
        # Only timings the robot actually recorded are reported.
        if log_key in robot.logs:
            add_timing(label, robot.logs[log_key])

    # total step time displayed in milliseconds and its frequency
    add_timing("dt", dt_s)

    # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
    if not robot.robot_type.startswith(("stretch", "piper")):
        for arm in robot.leader_arms:
            add_logged_timing("dtRlead", f"read_leader_{arm}_pos_dt_s")

        for arm in robot.follower_arms:
            add_logged_timing("dtWfoll", f"write_follower_{arm}_goal_pos_dt_s")
            add_logged_timing("dtRfoll", f"read_follower_{arm}_pos_dt_s")

        for camera in robot.cameras:
            add_logged_timing(f"dtR{camera}", f"read_camera_{camera}_dt_s")

    logging.info(" ".join(parts))
|
||
|
||
|
||
@cache
def is_headless():
    """Return True when `pynput` cannot be imported, False otherwise.

    A failed import is treated as running without a monitor: the reason is
    printed with its traceback. The result is cached, so the import is only
    attempted once.
    """
    try:
        import pynput  # noqa
    except Exception:
        print(
            "Error trying to import pynput. Switching to headless mode. "
            "As a result, the video stream from the cameras won't be shown, "
            "and you won't be able to change the control flow with keyboards. "
            "For more info, see traceback below.\n"
        )
        traceback.print_exc()
        print()
        return True

    return False
|
||
|
||
|
||
def predict_action(observation, policy, device, use_amp):
    """Run `policy.select_action` on one observation and return the action on CPU.

    Works on a shallow copy of ``observation`` so the caller's dict keeps its
    original tensors. Entries whose key contains "image" are converted to
    float32 in [0, 1] and permuted to channel-first; every entry then gets a
    leading batch dimension and is moved to ``device``. Autocast is only
    enabled when ``device`` is CUDA and ``use_amp`` is truthy.
    """
    batch = copy(observation)
    amp_enabled = device.type == "cuda" and use_amp
    autocast_ctx = torch.autocast(device_type=device.type) if amp_enabled else nullcontext()

    with torch.inference_mode(), autocast_ctx:
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for key in batch:
            tensor = batch[key]
            if "image" in key:
                tensor = tensor.type(torch.float32) / 255
                tensor = tensor.permute(2, 0, 1).contiguous()
            batch[key] = tensor.unsqueeze(0).to(device)

        # Compute the next action with the policy based on the current observation,
        # then drop the batch dimension and bring the result back to cpu.
        action = policy.select_action(batch).squeeze(0).to("cpu")

    return action
|
||
|
||
|
||
def init_keyboard_listener():
    """Create the shared event flags and, when possible, a keyboard listener.

    Returns a ``(listener, events)`` pair. ``events`` is a dict of boolean
    flags that the key handler switches to True. In a headless environment no
    listener is started and ``listener`` is None.
    """
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {
        "exit_early": False,
        "rerecord_episode": False,
        "stop_recording": False,
        "confirm_save": False,
        "discard_episode": False,
    }

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    # (key, message printed on press, flags raised in that order)
    bindings = (
        (
            keyboard.Key.right,
            "Right arrow key pressed. Exiting current recording...",
            ("exit_early",),
        ),
        (
            keyboard.Key.left,
            "Left arrow key pressed. Interrupting and preparing to rerecord...",
            ("rerecord_episode", "exit_early"),
        ),
        (
            keyboard.Key.esc,
            "Escape key pressed. Stopping data recording session...",
            ("stop_recording", "exit_early"),
        ),
        (
            keyboard.Key.enter,
            "Enter key pressed. Confirming and saving current episode...",
            ("confirm_save",),
        ),
        (
            keyboard.Key.backspace,
            "Back key pressed. Discarding completed episode...",
            ("discard_episode",),
        ),
    )

    def on_press(key):
        try:
            for bound_key, message, flags in bindings:
                if key == bound_key:
                    print(message)
                    for flag in flags:
                        events[flag] = True
                    break
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events
|
||
|
||
|
||
def warmup_record(
    robot,
    events,
    enable_teleoperation,
    warmup_time_s,
    display_cameras,
    fps,
):
    """Run the control loop for ``warmup_time_s`` seconds without a dataset or policy.

    Teleoperation is active during the warm-up only when
    ``enable_teleoperation`` is true.
    """
    loop_kwargs = {
        "robot": robot,
        "control_time_s": warmup_time_s,
        "teleoperate": enable_teleoperation,
        "display_cameras": display_cameras,
        "events": events,
        "fps": fps,
    }
    control_loop(**loop_kwargs)
|
||
|
||
|
||
def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    device,
    use_amp,
    fps,
    single_task,
):
    """Run the control loop for one episode, adding its frames to ``dataset``.

    The robot is teleoperated when no ``policy`` is given; otherwise the
    policy drives it.
    """
    # Teleoperation and policy control are mutually exclusive in the control loop.
    use_teleoperation = policy is None

    loop_kwargs = {
        "robot": robot,
        "control_time_s": episode_time_s,
        "teleoperate": use_teleoperation,
        "display_cameras": display_cameras,
        "dataset": dataset,
        "events": events,
        "policy": policy,
        "device": device,
        "use_amp": use_amp,
        "fps": fps,
        "single_task": single_task,
    }
    control_loop(**loop_kwargs)
|
||
|
||
|
||
@safe_stop_image_writer
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    device: torch.device | str | None = None,
    use_amp: bool | None = None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """Step the robot repeatedly until the time budget elapses or an early exit is requested.

    Each iteration either performs a teleoperation step or captures an
    observation and, when a ``policy`` is given, sends the predicted action to
    the robot. The resulting frame is optionally added to ``dataset`` and the
    camera images optionally displayed. When ``fps`` is set the iteration is
    padded with a busy wait to last ``1 / fps`` seconds.

    Args:
        robot: Robot to control; it is connected first if it is not already.
        control_time_s: Maximum duration of the loop in seconds; None means no limit.
        teleoperate: Drive the robot through ``robot.teleop_step``.
        display_cameras: Show the camera images (ignored in headless mode).
        dataset: Dataset that receives one frame per iteration, if any.
        events: Dict holding at least an ``"exit_early"`` flag. The flag is
            checked at the end of every iteration and reset once it has been honored.
        policy: Policy used to predict actions when not teleoperating.
        device: Torch device for the policy; strings are resolved with
            ``get_safe_torch_device``.
        use_amp: Forwarded to ``predict_action``.
        fps: Target loop frequency.
        single_task: Task description stored with every recorded frame.

    Raises:
        ValueError: If ``teleoperate`` is combined with a ``policy``, if a
            dataset is given without ``single_task``, if a dataset is given
            with neither teleoperation nor a policy to produce actions, or if
            ``dataset.fps`` differs from ``fps``.
    """
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if teleoperate and policy is not None:
        raise ValueError("When `teleoperate` is True, `policy` should be None.")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and not teleoperate and policy is None:
        # Without teleoperation or a policy no action is ever produced, so no
        # frame could be assembled for the dataset.
        raise ValueError("Recording to a `dataset` requires `teleoperate=True` or a `policy`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        # Read the attribute that was just compared; indexing the dataset with
        # a string key would raise before this error could be reported.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    if isinstance(device, str):
        device = get_safe_torch_device(device)

    timestamp = 0
    start_episode_t = time.perf_counter()
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step(record_data=True)
        else:
            observation = robot.capture_observation()

            if policy is not None:
                pred_action = predict_action(observation, policy, device, use_amp)
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            # All camera images are shown together in a single window.
            display_observations_combined(observation)

        if fps is not None:
            # Pad the iteration so that it lasts 1 / fps seconds.
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            # Reset the flag so the next control loop does not exit immediately.
            events["exit_early"] = False
            break
|
||
|
||
import numpy as np
|
||
def display_observations_combined(observation, display_cameras=True):
    """Show all camera images of ``observation`` tiled in a single OpenCV window.

    Every entry whose key contains "image" is resized to 640x480 if needed,
    converted from RGB to BGR, labeled with its key and placed on one canvas
    (at most four images are shown). The window is created, sized and centered
    only the first time, so it is not re-placed on every frame.

    Does nothing when ``display_cameras`` is false, in headless mode, or when
    the observation contains no image entry.

    Args:
        observation: Mapping from names to values; image entries must offer ``.numpy()``.
        display_cameras: Set to False to disable the display.
    """
    if not display_cameras or is_headless():
        return

    image_keys = [key for key in observation if "image" in key]
    if not image_keys:
        return

    cell_width, cell_height = 640, 480
    margin = 10  # gap between tiles, in pixels

    # Collect every image in BGR format at the common tile size.
    images = []
    for key in image_keys:
        img = observation[key].numpy()
        if img.shape[0] != cell_height or img.shape[1] != cell_width:
            img = cv2.resize(img, (cell_width, cell_height))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        # Label the tile with the observation key.
        cv2.putText(img, key, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        images.append(img)

    grid_image = _compose_camera_grid(images, cell_width, cell_height, margin)
    # Take the size from the composed image so that it is also defined for the
    # single-camera layout, which has no separate canvas.
    grid_height, grid_width = grid_image.shape[:2]

    # A fixed window name lets every frame reuse the same window.
    window_name = "Camera Views"

    # Probe for an existing window; a missing window either yields a negative
    # property value or an OpenCV error.
    try:
        window_exists = cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) >= 0
    except cv2.error:
        window_exists = False

    if not window_exists:
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

        # Position and size only need to be set once, at creation time.
        screen_width = 1920  # assumed screen width
        screen_height = 1080  # assumed screen height

        # Scale the grid to fit the assumed screen, leaving a small border.
        scale_factor = min(screen_width / grid_width, screen_height / grid_height) * 0.95
        display_width = int(grid_width * scale_factor)
        display_height = int(grid_height * scale_factor)

        cv2.resizeWindow(window_name, display_width, display_height)
        # Center the window on the assumed screen.
        cv2.moveWindow(
            window_name,
            (screen_width - display_width) // 2,
            (screen_height - display_height) // 2,
        )

    cv2.imshow(window_name, grid_image)
    cv2.waitKey(1)


def _compose_camera_grid(images, cell_width, cell_height, margin):
    """Tile up to four equally sized 3-channel images onto one black canvas.

    One image is returned as is, two are placed side by side, three use two
    on the top row and one centered below, and four or more fill a 2x2 grid
    (images beyond the fourth are ignored).
    """
    count = len(images)
    if count == 1:
        return images[0]

    cols = 2
    rows = 1 if count == 2 else 2
    grid_width = cols * cell_width + (cols - 1) * margin
    grid_height = rows * cell_height + (rows - 1) * margin
    canvas = np.zeros((grid_height, grid_width, 3), dtype=np.uint8)

    for index, img in enumerate(images[:4]):
        row, col = divmod(index, cols)
        y_start = row * (cell_height + margin)
        x_start = col * (cell_width + margin)
        if count == 3 and index == 2:
            # A lone image on the bottom row is centered horizontally.
            x_start = (grid_width - cell_width) // 2
        canvas[y_start:y_start + cell_height, x_start:x_start + cell_width] = img

    return canvas
|
||
|
||
|
||
|
||
def reset_environment(robot, events, reset_time_s, fps):
    """Teleoperate the robot for ``reset_time_s`` seconds so the scene can be reset.

    If the robot provides a ``teleop_safety_stop`` method, it is called first.
    No dataset is passed to the control loop.
    """
    # TODO(rcadene): refactor warmup_record and reset_environment
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    loop_kwargs = {
        "robot": robot,
        "control_time_s": reset_time_s,
        "teleoperate": True,
        "events": events,
        "fps": fps,
    }
    control_loop(**loop_kwargs)
|
||
|
||
|
||
def stop_recording(robot, listener, display_cameras):
    """Disconnect the robot and release the keyboard listener and camera windows.

    The listener and the OpenCV windows are only touched outside headless mode.
    """
    robot.disconnect()

    if is_headless():
        return

    if listener is not None:
        listener.stop()

    if display_cameras:
        cv2.destroyAllWindows()
|
||
|
||
|
||
def sanity_check_dataset_name(repo_id, policy_cfg):
    """Check that the dataset name follows the ``eval_`` naming convention.

    The dataset name is the last ``/``-separated component of ``repo_id``. It
    must start with ``eval_`` exactly when a policy config is provided.

    Args:
        repo_id: Dataset repository id, e.g. ``"user/eval_my_dataset"``.
        policy_cfg: Policy config exposing a ``type`` attribute, or None.

    Raises:
        ValueError: If the name starts with ``eval_`` while ``policy_cfg`` is
            None, or does not start with ``eval_`` while a policy is provided.
    """
    dataset_name = repo_id.split("/")[-1]
    is_eval_name = dataset_name.startswith("eval_")

    # Check if dataset_name starts with "eval_" but policy is missing
    if is_eval_name and policy_cfg is None:
        # policy_cfg is None in this branch, so the message must not read
        # `policy_cfg.type`: doing so raised AttributeError instead of ValueError.
        raise ValueError(
            f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
        )

    # Check if dataset_name does not start with "eval_" but policy is provided
    if not is_eval_name and policy_cfg is not None:
        raise ValueError(
            f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
        )
|
||
|
||
|
||
def sanity_check_dataset_robot_compatibility(
    dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
    """Verify that an existing dataset matches the current robot setup.

    The robot type, the fps and the features stored in the dataset are
    compared against the current values with ``DeepDiff``, ignoring paths
    that end with ``['info']``.

    Raises:
        ValueError: If at least one of the compared fields differs; the
            message lists every mismatch.
    """
    checks = [
        ("robot_type", dataset.meta.robot_type, robot.robot_type),
        ("fps", dataset.fps, fps),
        ("features", dataset.features, get_features_from_robot(robot, use_videos)),
    ]

    mismatches = [
        f"{field}: expected {present_value}, got {dataset_value}"
        for field, dataset_value, present_value in checks
        if DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
    ]

    if not mismatches:
        return

    raise ValueError(
        "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
    )
|