Add support for video=False in record (no tested yet)

This commit is contained in:
Remi Cadene
2024-09-26 11:41:19 +02:00
parent 2c0171632f
commit 500d505bf6
2 changed files with 45 additions and 28 deletions

View File

@@ -337,9 +337,6 @@ def record(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})." f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
) )
if not video:
raise NotImplementedError()
if not robot.is_connected: if not robot.is_connected:
robot.connect() robot.connect()
@@ -550,6 +547,7 @@ def record(
num_frames = frame_index num_frames = frame_index
for key in image_keys: for key in image_keys:
if video:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4" fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname video_path = local_dir / "videos" / fname
@@ -560,6 +558,13 @@ def record(
for i in range(num_frames): for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps}) ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
else:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
ep_dict[key] = []
for frame_index in range(num_frames):
img_path = imgs_dir / f"frame_{frame_index:06d}.png"
ep_dict[key].append(img_path)
for key in not_image_keys: for key in not_image_keys:
ep_dict[key] = torch.stack(ep_dict[key]) ep_dict[key] = torch.stack(ep_dict[key])
@@ -622,6 +627,7 @@ def record(
num_episodes = episode_index num_episodes = episode_index
if video:
logging.info("Encoding videos") logging.info("Encoding videos")
say("Encoding videos") say("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos # Use ffmpeg to convert frames stored as png into mp4 videos

View File

@@ -69,6 +69,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
num_episodes=2, num_episodes=2,
run_compute_stats=False, run_compute_stats=False,
push_to_hub=False, push_to_hub=False,
video=False,
) )
@@ -91,6 +92,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
episode_time_s=1, episode_time_s=1,
num_episodes=2, num_episodes=2,
push_to_hub=False, push_to_hub=False,
video=False,
) )
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id) replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
@@ -106,6 +108,15 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
record(robot, policy, cfg, warmup_time_s=1, episode_time_s=1, run_compute_stats=False, push_to_hub=False) record(
robot,
policy,
cfg,
warmup_time_s=1,
episode_time_s=1,
run_compute_stats=False,
push_to_hub=False,
video=False,
)
del robot del robot