forked from tangger/lerobot
Add support for video=False in record (no tested yet)
This commit is contained in:
@@ -337,9 +337,6 @@ def record(
|
|||||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not video:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
@@ -550,15 +547,23 @@ def record(
|
|||||||
num_frames = frame_index
|
num_frames = frame_index
|
||||||
|
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
if video:
|
||||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||||
video_path = local_dir / "videos" / fname
|
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||||
if video_path.exists():
|
video_path = local_dir / "videos" / fname
|
||||||
video_path.unlink()
|
if video_path.exists():
|
||||||
# Store the reference to the video frame, even tho the videos are not yet encoded
|
video_path.unlink()
|
||||||
ep_dict[key] = []
|
# Store the reference to the video frame, even tho the videos are not yet encoded
|
||||||
for i in range(num_frames):
|
ep_dict[key] = []
|
||||||
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
|
for i in range(num_frames):
|
||||||
|
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
|
||||||
|
|
||||||
|
else:
|
||||||
|
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||||
|
ep_dict[key] = []
|
||||||
|
for frame_index in range(num_frames):
|
||||||
|
img_path = imgs_dir / f"frame_{frame_index:06d}.png"
|
||||||
|
ep_dict[key].append(img_path)
|
||||||
|
|
||||||
for key in not_image_keys:
|
for key in not_image_keys:
|
||||||
ep_dict[key] = torch.stack(ep_dict[key])
|
ep_dict[key] = torch.stack(ep_dict[key])
|
||||||
@@ -622,21 +627,22 @@ def record(
|
|||||||
|
|
||||||
num_episodes = episode_index
|
num_episodes = episode_index
|
||||||
|
|
||||||
logging.info("Encoding videos")
|
if video:
|
||||||
say("Encoding videos")
|
logging.info("Encoding videos")
|
||||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
say("Encoding videos")
|
||||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||||
for key in image_keys:
|
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
for key in image_keys:
|
||||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||||
video_path = local_dir / "videos" / fname
|
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||||
if video_path.exists():
|
video_path = local_dir / "videos" / fname
|
||||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
if video_path.exists():
|
||||||
continue
|
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
continue
|
||||||
# since video encoding with ffmpeg is already using multithreading.
|
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||||
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
|
# since video encoding with ffmpeg is already using multithreading.
|
||||||
shutil.rmtree(tmp_imgs_dir)
|
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
|
||||||
|
shutil.rmtree(tmp_imgs_dir)
|
||||||
|
|
||||||
logging.info("Concatenating episodes")
|
logging.info("Concatenating episodes")
|
||||||
ep_dicts = []
|
ep_dicts = []
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
|||||||
num_episodes=2,
|
num_episodes=2,
|
||||||
run_compute_stats=False,
|
run_compute_stats=False,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
|
video=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -91,6 +92,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||||||
episode_time_s=1,
|
episode_time_s=1,
|
||||||
num_episodes=2,
|
num_episodes=2,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
|
video=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||||
@@ -106,6 +108,15 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||||||
|
|
||||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||||
|
|
||||||
record(robot, policy, cfg, warmup_time_s=1, episode_time_s=1, run_compute_stats=False, push_to_hub=False)
|
record(
|
||||||
|
robot,
|
||||||
|
policy,
|
||||||
|
cfg,
|
||||||
|
warmup_time_s=1,
|
||||||
|
episode_time_s=1,
|
||||||
|
run_compute_stats=False,
|
||||||
|
push_to_hub=False,
|
||||||
|
video=False,
|
||||||
|
)
|
||||||
|
|
||||||
del robot
|
del robot
|
||||||
|
|||||||
Reference in New Issue
Block a user