Cleanup control_utils
This commit is contained in:
@@ -18,25 +18,20 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
import traceback
|
import traceback
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import rerun as rr
|
|
||||||
import torch
|
import torch
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import DEFAULT_FEATURES
|
from lerobot.common.datasets.utils import DEFAULT_FEATURES
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.robots import Robot
|
from lerobot.common.robots import Robot
|
||||||
from lerobot.common.utils.robot_utils import busy_wait
|
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
|
||||||
|
|
||||||
|
|
||||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||||
@@ -173,152 +168,6 @@ def init_keyboard_listener():
|
|||||||
return listener, events
|
return listener, events
|
||||||
|
|
||||||
|
|
||||||
def warmup_record(
|
|
||||||
robot,
|
|
||||||
events,
|
|
||||||
enable_teleoperation,
|
|
||||||
warmup_time_s,
|
|
||||||
display_data,
|
|
||||||
fps,
|
|
||||||
):
|
|
||||||
control_loop(
|
|
||||||
robot=robot,
|
|
||||||
control_time_s=warmup_time_s,
|
|
||||||
display_data=display_data,
|
|
||||||
events=events,
|
|
||||||
fps=fps,
|
|
||||||
teleoperate=enable_teleoperation,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def record_episode(
|
|
||||||
robot,
|
|
||||||
dataset,
|
|
||||||
events,
|
|
||||||
episode_time_s,
|
|
||||||
display_data,
|
|
||||||
policy,
|
|
||||||
fps,
|
|
||||||
single_task,
|
|
||||||
):
|
|
||||||
control_loop(
|
|
||||||
robot=robot,
|
|
||||||
control_time_s=episode_time_s,
|
|
||||||
display_data=display_data,
|
|
||||||
dataset=dataset,
|
|
||||||
events=events,
|
|
||||||
policy=policy,
|
|
||||||
fps=fps,
|
|
||||||
teleoperate=policy is None,
|
|
||||||
single_task=single_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@safe_stop_image_writer
|
|
||||||
def control_loop(
|
|
||||||
robot,
|
|
||||||
control_time_s=None,
|
|
||||||
teleoperate=False,
|
|
||||||
display_data=False,
|
|
||||||
dataset: LeRobotDataset | None = None,
|
|
||||||
events=None,
|
|
||||||
policy: PreTrainedPolicy = None,
|
|
||||||
fps: int | None = None,
|
|
||||||
single_task: str | None = None,
|
|
||||||
):
|
|
||||||
# TODO(rcadene): Add option to record logs
|
|
||||||
if not robot.is_connected:
|
|
||||||
robot.connect()
|
|
||||||
|
|
||||||
if events is None:
|
|
||||||
events = {"exit_early": False}
|
|
||||||
|
|
||||||
if control_time_s is None:
|
|
||||||
control_time_s = float("inf")
|
|
||||||
|
|
||||||
if teleoperate and policy is not None:
|
|
||||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
|
||||||
|
|
||||||
if dataset is not None and single_task is None:
|
|
||||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
|
||||||
|
|
||||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
|
||||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
|
||||||
|
|
||||||
timestamp = 0
|
|
||||||
start_episode_t = time.perf_counter()
|
|
||||||
|
|
||||||
# Controls starts, if policy is given it needs cleaning up
|
|
||||||
if policy is not None:
|
|
||||||
policy.reset()
|
|
||||||
|
|
||||||
while timestamp < control_time_s:
|
|
||||||
start_loop_t = time.perf_counter()
|
|
||||||
|
|
||||||
if teleoperate:
|
|
||||||
observation, action = robot.teleop_step(record_data=True)
|
|
||||||
else:
|
|
||||||
observation = robot.capture_observation()
|
|
||||||
action = None
|
|
||||||
|
|
||||||
if policy is not None:
|
|
||||||
pred_action = predict_action(
|
|
||||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
|
||||||
)
|
|
||||||
# Action can eventually be clipped using `max_relative_target`,
|
|
||||||
# so action actually sent is saved in the dataset.
|
|
||||||
action = robot.send_action(pred_action)
|
|
||||||
action = {"action": action}
|
|
||||||
|
|
||||||
if dataset is not None:
|
|
||||||
frame = {**observation, **action, "task": single_task}
|
|
||||||
dataset.add_frame(frame)
|
|
||||||
|
|
||||||
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
|
|
||||||
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
|
|
||||||
if action is not None:
|
|
||||||
for k, v in action.items():
|
|
||||||
for i, vv in enumerate(v):
|
|
||||||
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
|
||||||
|
|
||||||
image_keys = [key for key in observation if "image" in key]
|
|
||||||
for key in image_keys:
|
|
||||||
rr.log(key, rr.Image(observation[key].numpy()), static=True)
|
|
||||||
|
|
||||||
if fps is not None:
|
|
||||||
dt_s = time.perf_counter() - start_loop_t
|
|
||||||
busy_wait(1 / fps - dt_s)
|
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_loop_t
|
|
||||||
log_control_info(robot, dt_s, fps=fps)
|
|
||||||
|
|
||||||
timestamp = time.perf_counter() - start_episode_t
|
|
||||||
if events["exit_early"]:
|
|
||||||
events["exit_early"] = False
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def reset_environment(robot, events, reset_time_s, fps):
|
|
||||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
|
||||||
if has_method(robot, "teleop_safety_stop"):
|
|
||||||
robot.teleop_safety_stop()
|
|
||||||
|
|
||||||
control_loop(
|
|
||||||
robot=robot,
|
|
||||||
control_time_s=reset_time_s,
|
|
||||||
events=events,
|
|
||||||
fps=fps,
|
|
||||||
teleoperate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_data):
|
|
||||||
robot.disconnect()
|
|
||||||
|
|
||||||
if not is_headless() and listener is not None:
|
|
||||||
listener.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||||
_, dataset_name = repo_id.split("/")
|
_, dataset_name = repo_id.split("/")
|
||||||
# either repo_id doesnt start with "eval_" and there is no policy
|
# either repo_id doesnt start with "eval_" and there is no policy
|
||||||
|
|||||||
@@ -131,8 +131,6 @@ class RecordConfig:
|
|||||||
teleop: TeleoperatorConfig | None = None
|
teleop: TeleoperatorConfig | None = None
|
||||||
# Whether to control the robot with a policy
|
# Whether to control the robot with a policy
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
|
||||||
warmup_time_s: int | float = 10
|
|
||||||
# Display all cameras on screen
|
# Display all cameras on screen
|
||||||
display_data: bool = False
|
display_data: bool = False
|
||||||
# Use vocal synthesis to read events.
|
# Use vocal synthesis to read events.
|
||||||
@@ -324,7 +322,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
single_task=cfg.dataset.single_task,
|
single_task=cfg.dataset.single_task,
|
||||||
display_data=cfg.display_data,
|
display_data=cfg.display_data,
|
||||||
)
|
)
|
||||||
# reset_environment(robot, events, cfg.dataset.reset_time_s, cfg.dataset.fps)
|
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
log_say("Re-record episode", cfg.play_sounds)
|
log_say("Re-record episode", cfg.play_sounds)
|
||||||
|
|||||||
Reference in New Issue
Block a user