Add support for video=False in record (no tested yet)

This commit is contained in:
Remi Cadene
2024-09-26 11:41:19 +02:00
parent 2c0171632f
commit 500d505bf6
2 changed files with 45 additions and 28 deletions

View File

@@ -337,9 +337,6 @@ def record(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})." f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
) )
if not video:
raise NotImplementedError()
if not robot.is_connected: if not robot.is_connected:
robot.connect() robot.connect()
@@ -550,15 +547,23 @@ def record(
num_frames = frame_index num_frames = frame_index
for key in image_keys: for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" if video:
fname = f"{key}_episode_{episode_index:06d}.mp4" tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
video_path = local_dir / "videos" / fname fname = f"{key}_episode_{episode_index:06d}.mp4"
if video_path.exists(): video_path = local_dir / "videos" / fname
video_path.unlink() if video_path.exists():
# Store the reference to the video frame, even tho the videos are not yet encoded video_path.unlink()
ep_dict[key] = [] # Store the reference to the video frame, even tho the videos are not yet encoded
for i in range(num_frames): ep_dict[key] = []
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps}) for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
else:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
ep_dict[key] = []
for frame_index in range(num_frames):
img_path = imgs_dir / f"frame_{frame_index:06d}.png"
ep_dict[key].append(img_path)
for key in not_image_keys: for key in not_image_keys:
ep_dict[key] = torch.stack(ep_dict[key]) ep_dict[key] = torch.stack(ep_dict[key])
@@ -622,21 +627,22 @@ def record(
num_episodes = episode_index num_episodes = episode_index
logging.info("Encoding videos") if video:
say("Encoding videos") logging.info("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos say("Encoding videos")
for episode_index in tqdm.tqdm(range(num_episodes)): # Use ffmpeg to convert frames stored as png into mp4 videos
for key in image_keys: for episode_index in tqdm.tqdm(range(num_episodes)):
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" for key in image_keys:
fname = f"{key}_episode_{episode_index:06d}.mp4" tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
video_path = local_dir / "videos" / fname fname = f"{key}_episode_{episode_index:06d}.mp4"
if video_path.exists(): video_path = local_dir / "videos" / fname
# Skip if video is already encoded. Could be the case when resuming data recording. if video_path.exists():
continue # Skip if video is already encoded. Could be the case when resuming data recording.
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, continue
# since video encoding with ffmpeg is already using multithreading. # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True) # since video encoding with ffmpeg is already using multithreading.
shutil.rmtree(tmp_imgs_dir) encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
logging.info("Concatenating episodes") logging.info("Concatenating episodes")
ep_dicts = [] ep_dicts = []

View File

@@ -69,6 +69,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
num_episodes=2, num_episodes=2,
run_compute_stats=False, run_compute_stats=False,
push_to_hub=False, push_to_hub=False,
video=False,
) )
@@ -91,6 +92,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
episode_time_s=1, episode_time_s=1,
num_episodes=2, num_episodes=2,
push_to_hub=False, push_to_hub=False,
video=False,
) )
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id) replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
@@ -106,6 +108,15 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
record(robot, policy, cfg, warmup_time_s=1, episode_time_s=1, run_compute_stats=False, push_to_hub=False) record(
robot,
policy,
cfg,
warmup_time_s=1,
episode_time_s=1,
run_compute_stats=False,
push_to_hub=False,
video=False,
)
del robot del robot