diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index a23c21c6..995b19c9 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -337,9 +337,6 @@ def record( f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})." ) - if not video: - raise NotImplementedError() - if not robot.is_connected: robot.connect() @@ -550,15 +547,23 @@ def record( num_frames = frame_index for key in image_keys: - tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" - fname = f"{key}_episode_{episode_index:06d}.mp4" - video_path = local_dir / "videos" / fname - if video_path.exists(): - video_path.unlink() - # Store the reference to the video frame, even tho the videos are not yet encoded - ep_dict[key] = [] - for i in range(num_frames): - ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps}) + if video: + tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" + fname = f"{key}_episode_{episode_index:06d}.mp4" + video_path = local_dir / "videos" / fname + if video_path.exists(): + video_path.unlink() + # Store the reference to the video frame, even tho the videos are not yet encoded + ep_dict[key] = [] + for i in range(num_frames): + ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps}) + + else: + imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" + ep_dict[key] = [] + for frame_index in range(num_frames): + img_path = imgs_dir / f"frame_{frame_index:06d}.png" + ep_dict[key].append(img_path) for key in not_image_keys: ep_dict[key] = torch.stack(ep_dict[key]) @@ -622,21 +627,22 @@ def record( num_episodes = episode_index - logging.info("Encoding videos") - say("Encoding videos") - # Use ffmpeg to convert frames stored as png into mp4 videos - for episode_index in tqdm.tqdm(range(num_episodes)): - for key in image_keys: - tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" - fname = f"{key}_episode_{episode_index:06d}.mp4" - video_path = local_dir / "videos" / fname - if video_path.exists(): - # Skip if video is already encoded. Could be the case when resuming data recording. - continue - # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, - # since video encoding with ffmpeg is already using multithreading. - encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True) - shutil.rmtree(tmp_imgs_dir) + if video: + logging.info("Encoding videos") + say("Encoding videos") + # Use ffmpeg to convert frames stored as png into mp4 videos + for episode_index in tqdm.tqdm(range(num_episodes)): + for key in image_keys: + tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" + fname = f"{key}_episode_{episode_index:06d}.mp4" + video_path = local_dir / "videos" / fname + if video_path.exists(): + # Skip if video is already encoded. Could be the case when resuming data recording. + continue + # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, + # since video encoding with ffmpeg is already using multithreading. + encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True) + shutil.rmtree(tmp_imgs_dir) logging.info("Concatenating episodes") ep_dicts = [] diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 14fe9158..857d809b 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -69,6 +69,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): num_episodes=2, run_compute_stats=False, push_to_hub=False, + video=False, ) @@ -91,6 +92,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): episode_time_s=1, num_episodes=2, push_to_hub=False, + video=False, ) replay(robot, episode=0, fps=30, root=root, repo_id=repo_id) @@ -106,6 +108,15 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) - record(robot, policy, cfg, warmup_time_s=1, episode_time_s=1, run_compute_stats=False, push_to_hub=False) + record( + robot, + policy, + cfg, + warmup_time_s=1, + episode_time_s=1, + run_compute_stats=False, + push_to_hub=False, + video=False, + ) del robot