Enable CI for robot devices with mocked versions (#398)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Remi
2024-10-03 17:05:23 +02:00
committed by GitHub
parent 72f402d44b
commit 26f97cfd17
18 changed files with 1053 additions and 237 deletions

View File

@@ -242,7 +242,8 @@ def is_headless():
########################################################################################
def calibrate(robot: Robot, arms: list[str] | None):
def get_available_arms(robot):
# TODO(rcadene): moves this function in manipulator class?
available_arms = []
for name in robot.follower_arms:
arm_id = get_arm_id(name, "follower")
@@ -250,9 +251,12 @@ def calibrate(robot: Robot, arms: list[str] | None):
for name in robot.leader_arms:
arm_id = get_arm_id(name, "leader")
available_arms.append(arm_id)
return available_arms
def calibrate(robot: Robot, arms: list[str] | None):
available_arms = get_available_arms(robot)
unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms]
available_arms_str = " ".join(available_arms)
unknown_arms_str = " ".join(unknown_arms)
@@ -323,6 +327,7 @@ def record(
tags=None,
num_image_writers_per_camera=4,
force_override=False,
display_cameras=True,
):
# TODO(rcadene): Add option to record logs
# TODO(rcadene): Clean this function via decomposition in higher level functions
@@ -333,9 +338,6 @@ def record(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
)
if not video:
raise NotImplementedError()
if not robot.is_connected:
robot.connect()
@@ -359,7 +361,7 @@ def record(
episode_index = 0
if is_headless():
logging.info(
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
@@ -427,7 +429,7 @@ def record(
else:
observation = robot.capture_observation()
if not is_headless():
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@@ -445,6 +447,7 @@ def record(
# Using `with` to exist smoothly if an execption is raised.
futures = []
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
num_image_writers = max(num_image_writers, 1)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
# Start recording all episodes
while episode_index < num_episodes:
@@ -472,7 +475,7 @@ def record(
)
]
if not is_headless():
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@@ -545,15 +548,23 @@ def record(
num_frames = frame_index
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
video_path.unlink()
# Store the reference to the video frame, even tho the videos are not yet encoded
ep_dict[key] = []
for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
if video:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
video_path.unlink()
# Store the reference to the video frame, even tho the videos are not yet encoded
ep_dict[key] = []
for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
else:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
ep_dict[key] = []
for i in range(num_frames):
img_path = imgs_dir / f"frame_{i:06d}.png"
ep_dict[key].append({"path": str(img_path)})
for key in not_image_keys:
ep_dict[key] = torch.stack(ep_dict[key])
@@ -612,26 +623,27 @@ def record(
break
robot.disconnect()
if not is_headless():
if display_cameras and not is_headless():
cv2.destroyAllWindows()
num_episodes = episode_index
logging.info("Encoding videos")
say("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
if video:
logging.info("Encoding videos")
say("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
logging.info("Concatenating episodes")
ep_dicts = []