Add support for Stretch (hello-robot) (#409)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
This commit is contained in:
Simon Alibert
2024-10-04 18:56:42 +02:00
committed by GitHub
parent 26f97cfd17
commit 1a343c3591
20 changed files with 5052 additions and 1652 deletions

View File

@@ -128,7 +128,7 @@ from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot, get_arm_id
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import (
@@ -176,7 +176,7 @@ def none_or_int(value):
return int(value)
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@@ -195,24 +195,26 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
# total step time displayed in milliseconds and its frequency
log_dt("dt", dt_s)
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith("stretch"):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
info_str = " ".join(log_items)
logging.info(info_str)
@@ -237,9 +239,8 @@ def is_headless():
return True
########################################################################################
# Control modes
########################################################################################
def has_method(_object: object, method_name: str):
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
def get_available_arms(robot):
@@ -254,7 +255,21 @@ def get_available_arms(robot):
return available_arms
########################################################################################
# Control modes
########################################################################################
@safe_disconnect
def calibrate(robot: Robot, arms: list[str] | None):
# TODO(aliberts): move this code in robots' classes
if robot.robot_type.startswith("stretch"):
if not robot.is_connected:
robot.connect()
if not robot.is_homed():
robot.home()
return
available_arms = get_available_arms(robot)
unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms]
available_arms_str = " ".join(available_arms)
@@ -289,6 +304,7 @@ def calibrate(robot: Robot, arms: list[str] | None):
print("Calibration is done! You can now teleoperate and record datasets!")
@safe_disconnect
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@@ -310,6 +326,7 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
break
@safe_disconnect
def record(
robot: Robot,
policy: torch.nn.Module | None = None,
@@ -443,6 +460,9 @@ def record(
timestamp = time.perf_counter() - start_warmup_t
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
# Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised.
futures = []
@@ -536,6 +556,10 @@ def record(
exit_early = False
break
# TODO(alibets): allow for teleop during reset
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
if not stop_recording:
# Start resetting env while the executor are finishing
logging.info("Reset the environment")