Merge remote-tracking branch 'origin/main' into user/rcadene/2024_09_10_train_aloha

This commit is contained in:
Remi Cadene
2024-10-04 19:08:55 +02:00
20 changed files with 5061 additions and 1681 deletions

View File

@@ -129,7 +129,7 @@ from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot, get_arm_id
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import (
@@ -177,7 +177,7 @@ def none_or_int(value):
return int(value)
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@@ -196,24 +196,26 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
# total step time displayed in milliseconds and its frequency
log_dt("dt", dt_s)
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith("stretch"):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
info_str = " ".join(log_items)
logging.info(info_str)
@@ -280,9 +282,8 @@ def stop_workers(workers, frame_queue):
process.join()
########################################################################################
# Control modes
########################################################################################
def has_method(_object: object, method_name: str):
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
def get_available_arms(robot):
@@ -297,7 +298,21 @@ def get_available_arms(robot):
return available_arms
########################################################################################
# Control modes
########################################################################################
@safe_disconnect
def calibrate(robot: Robot, arms: list[str] | None):
# TODO(aliberts): move this code in robots' classes
if robot.robot_type.startswith("stretch"):
if not robot.is_connected:
robot.connect()
if not robot.is_homed():
robot.home()
return
available_arms = get_available_arms(robot)
unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms]
available_arms_str = " ".join(available_arms)
@@ -332,6 +347,7 @@ def calibrate(robot: Robot, arms: list[str] | None):
print("Calibration is done! You can now teleoperate and record datasets!")
@safe_disconnect
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@@ -353,6 +369,7 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
break
@safe_disconnect
def record(
robot: Robot,
policy: torch.nn.Module | None = None,
@@ -486,7 +503,11 @@ def record(
timestamp = time.perf_counter() - start_warmup_t
# Save images using a worker process with writer threads to reach high fps (30 and more)
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
# Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised.
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
num_image_writers = max(num_image_writers, 1)
frame_queue = multiprocessing.Queue()
@@ -577,6 +598,10 @@ def record(
exit_early = False
break
# TODO(alibets): allow for teleop during reset
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
if not stop_recording:
# Start resetting env while the executor are finishing
logging.info("Reset the environment")