Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots

This commit is contained in:
Simon Alibert
2025-04-14 15:30:35 +02:00
16 changed files with 177 additions and 67 deletions

View File

@@ -24,7 +24,7 @@ from contextlib import nullcontext
from copy import copy
from functools import cache
import cv2
import rerun as rr
import torch
from deepdiff import DeepDiff
from termcolor import colored
@@ -174,13 +174,13 @@ def warmup_record(
events,
enable_teleoperation,
warmup_time_s,
display_cameras,
display_data,
fps,
):
control_loop(
robot=robot,
control_time_s=warmup_time_s,
display_cameras=display_cameras,
display_data=display_data,
events=events,
fps=fps,
teleoperate=enable_teleoperation,
@@ -192,7 +192,7 @@ def record_episode(
dataset,
events,
episode_time_s,
display_cameras,
display_data,
policy,
fps,
single_task,
@@ -200,7 +200,7 @@ def record_episode(
control_loop(
robot=robot,
control_time_s=episode_time_s,
display_cameras=display_cameras,
display_data=display_data,
dataset=dataset,
events=events,
policy=policy,
@@ -215,7 +215,7 @@ def control_loop(
robot,
control_time_s=None,
teleoperate=False,
display_cameras=False,
display_data=False,
dataset: LeRobotDataset | None = None,
events=None,
policy: PreTrainedPolicy = None,
@@ -264,11 +264,15 @@ def control_loop(
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
if display_cameras and not is_headless():
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
for k, v in action.items():
for i, vv in enumerate(v):
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
rr.log(key, rr.Image(observation[key].numpy()), static=True)
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
@@ -297,15 +301,11 @@ def reset_environment(robot, events, reset_time_s, fps):
)
def stop_recording(robot, listener, display_cameras):
def stop_recording(robot, listener, display_data):
robot.disconnect()
if not is_headless():
if listener is not None:
listener.stop()
if display_cameras:
cv2.destroyAllWindows()
if not is_headless() and listener is not None:
listener.stop()
def sanity_check_dataset_name(repo_id, policy_cfg):