Feat: Add Batched Video Encoding for Faster Dataset Recording (#1390)
* LeRobotDataset video encoding: updated `save_episode` method and added `batch_encode_videos` method to handle video encoding based on `batch_encoding_size`, allowing for both immediate and batched encoding. * LeRobotDataset video cleanup: Enabled individual episode cleanup and check for remaining PNG files before removing the `images` directory. * LeRobotDataset - VideoEncodingManager: added proper handling of pending episodes (encoding, cleaning) on exit or recording failures. * LeRobotDatasetMetadata: removed `update_video_info` to only update video info at episode index 0 encoding. * Adjusted the `record` function to utilize the new encoding management logic. * Removed `encode_videos` method from `LeRobotDataset` and `encode_episode_videos` outputs as they are nowhere used. --------- Signed-off-by: Xingdong Zuo <zuoxingdong@users.noreply.github.com> Co-authored-by: Xingdong Zuo <xingdong.zuo@navercorp.com> Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
This commit is contained in:
@@ -260,8 +260,6 @@ class LeRobotDatasetMetadata:
|
|||||||
|
|
||||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||||
self.info["total_videos"] += len(self.video_keys)
|
self.info["total_videos"] += len(self.video_keys)
|
||||||
if len(self.video_keys) > 0:
|
|
||||||
self.update_video_info()
|
|
||||||
|
|
||||||
write_info(self.info, self.root)
|
write_info(self.info, self.root)
|
||||||
|
|
||||||
@@ -342,6 +340,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
force_cache_sync: bool = False,
|
force_cache_sync: bool = False,
|
||||||
download_videos: bool = True,
|
download_videos: bool = True,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
|
batch_encoding_size: int = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||||
@@ -443,6 +442,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
True.
|
True.
|
||||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||||
|
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||||
|
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
@@ -454,6 +455,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
|
self.batch_encoding_size = batch_encoding_size
|
||||||
|
self.episodes_since_last_encoding = 0
|
||||||
|
|
||||||
# Unused attributes
|
# Unused attributes
|
||||||
self.image_writer = None
|
self.image_writer = None
|
||||||
@@ -811,6 +814,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
"""
|
"""
|
||||||
This will save to disk the current episode in self.episode_buffer.
|
This will save to disk the current episode in self.episode_buffer.
|
||||||
|
|
||||||
|
Video encoding is handled automatically based on batch_encoding_size:
|
||||||
|
- If batch_encoding_size == 1: Videos are encoded immediately after each episode
|
||||||
|
- If batch_encoding_size > 1: Videos are encoded in batches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
||||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||||
@@ -850,14 +857,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self._save_episode_table(episode_buffer, episode_index)
|
self._save_episode_table(episode_buffer, episode_index)
|
||||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||||
|
|
||||||
if len(self.meta.video_keys) > 0:
|
has_video_keys = len(self.meta.video_keys) > 0
|
||||||
video_paths = self.encode_episode_videos(episode_index)
|
use_batched_encoding = self.batch_encoding_size > 1
|
||||||
for key in self.meta.video_keys:
|
|
||||||
episode_buffer[key] = video_paths[key]
|
|
||||||
|
|
||||||
# `meta.save_episode` be executed after encoding the videos
|
if has_video_keys and not use_batched_encoding:
|
||||||
|
self.encode_episode_videos(episode_index)
|
||||||
|
|
||||||
|
# `meta.save_episode` should be executed after encoding the videos
|
||||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||||
|
|
||||||
|
# Check if we should trigger batch encoding
|
||||||
|
if has_video_keys and use_batched_encoding:
|
||||||
|
self.episodes_since_last_encoding += 1
|
||||||
|
if self.episodes_since_last_encoding == self.batch_encoding_size:
|
||||||
|
start_ep = self.num_episodes - self.batch_encoding_size
|
||||||
|
end_ep = self.num_episodes
|
||||||
|
logging.info(
|
||||||
|
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_ep} to {end_ep - 1}"
|
||||||
|
)
|
||||||
|
self.batch_encode_videos(start_ep, end_ep)
|
||||||
|
self.episodes_since_last_encoding = 0
|
||||||
|
|
||||||
|
# Episode data index and timestamp checking
|
||||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||||
check_timestamps_sync(
|
check_timestamps_sync(
|
||||||
@@ -868,16 +889,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.tolerance_s,
|
self.tolerance_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
video_files = list(self.root.rglob("*.mp4"))
|
# Verify that we have one parquet file per episode and the number of video files matches the number of encoded episodes
|
||||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
|
||||||
|
|
||||||
parquet_files = list(self.root.rglob("*.parquet"))
|
parquet_files = list(self.root.rglob("*.parquet"))
|
||||||
assert len(parquet_files) == self.num_episodes
|
assert len(parquet_files) == self.num_episodes
|
||||||
|
video_files = list(self.root.rglob("*.mp4"))
|
||||||
# delete images
|
assert len(video_files) == (self.num_episodes - self.episodes_since_last_encoding) * len(
|
||||||
img_dir = self.root / "images"
|
self.meta.video_keys
|
||||||
if img_dir.is_dir():
|
)
|
||||||
shutil.rmtree(self.root / "images")
|
|
||||||
|
|
||||||
if not episode_data: # Reset the buffer
|
if not episode_data: # Reset the buffer
|
||||||
self.episode_buffer = self.create_episode_buffer()
|
self.episode_buffer = self.create_episode_buffer()
|
||||||
@@ -894,6 +912,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def clear_episode_buffer(self) -> None:
|
def clear_episode_buffer(self) -> None:
|
||||||
episode_index = self.episode_buffer["episode_index"]
|
episode_index = self.episode_buffer["episode_index"]
|
||||||
|
|
||||||
|
# Clean up image files for the current episode buffer
|
||||||
if self.image_writer is not None:
|
if self.image_writer is not None:
|
||||||
for cam_key in self.meta.camera_keys:
|
for cam_key in self.meta.camera_keys:
|
||||||
img_dir = self._get_image_file_path(
|
img_dir = self._get_image_file_path(
|
||||||
@@ -930,25 +950,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
if self.image_writer is not None:
|
if self.image_writer is not None:
|
||||||
self.image_writer.wait_until_done()
|
self.image_writer.wait_until_done()
|
||||||
|
|
||||||
def encode_videos(self) -> None:
|
def encode_episode_videos(self, episode_index: int) -> None:
|
||||||
"""
|
"""
|
||||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||||
since video encoding with ffmpeg is already using multithreading.
|
since video encoding with ffmpeg is already using multithreading.
|
||||||
"""
|
|
||||||
for ep_idx in range(self.meta.total_episodes):
|
|
||||||
self.encode_episode_videos(ep_idx)
|
|
||||||
|
|
||||||
def encode_episode_videos(self, episode_index: int) -> dict:
|
This method handles video encoding steps:
|
||||||
|
- Video encoding via ffmpeg
|
||||||
|
- Video info updating in metadata
|
||||||
|
- Raw image cleanup
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_index (int): Index of the episode to encode.
|
||||||
"""
|
"""
|
||||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
|
||||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
|
||||||
since video encoding with ffmpeg is already using multithreading.
|
|
||||||
"""
|
|
||||||
video_paths = {}
|
|
||||||
for key in self.meta.video_keys:
|
for key in self.meta.video_keys:
|
||||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||||
video_paths[key] = str(video_path)
|
|
||||||
if video_path.is_file():
|
if video_path.is_file():
|
||||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||||
continue
|
continue
|
||||||
@@ -956,8 +973,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
episode_index=episode_index, image_key=key, frame_index=0
|
episode_index=episode_index, image_key=key, frame_index=0
|
||||||
).parent
|
).parent
|
||||||
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
||||||
|
shutil.rmtree(img_dir)
|
||||||
|
|
||||||
return video_paths
|
# Update video info (only needed when first episode is encoded since it reads from episode 0)
|
||||||
|
if len(self.meta.video_keys) > 0 and episode_index == 0:
|
||||||
|
self.meta.update_video_info()
|
||||||
|
write_info(self.meta.info, self.meta.root) # ensure video info always written properly
|
||||||
|
|
||||||
|
def batch_encode_videos(self, start_episode: int = 0, end_episode: int | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Batch encode videos for multiple episodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_episode: Starting episode index (inclusive)
|
||||||
|
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode
|
||||||
|
"""
|
||||||
|
if end_episode is None:
|
||||||
|
end_episode = self.meta.total_episodes
|
||||||
|
|
||||||
|
logging.info(f"Starting batch video encoding for episodes {start_episode} to {end_episode - 1}")
|
||||||
|
|
||||||
|
# Encode all episodes with cleanup enabled for individual episodes
|
||||||
|
for ep_idx in range(start_episode, end_episode):
|
||||||
|
logging.info(f"Encoding videos for episode {ep_idx}")
|
||||||
|
self.encode_episode_videos(ep_idx)
|
||||||
|
|
||||||
|
logging.info("Batch video encoding completed")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@@ -972,6 +1013,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
image_writer_processes: int = 0,
|
image_writer_processes: int = 0,
|
||||||
image_writer_threads: int = 0,
|
image_writer_threads: int = 0,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
|
batch_encoding_size: int = 1,
|
||||||
) -> "LeRobotDataset":
|
) -> "LeRobotDataset":
|
||||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||||
obj = cls.__new__(cls)
|
obj = cls.__new__(cls)
|
||||||
@@ -988,6 +1030,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.revision = None
|
obj.revision = None
|
||||||
obj.tolerance_s = tolerance_s
|
obj.tolerance_s = tolerance_s
|
||||||
obj.image_writer = None
|
obj.image_writer = None
|
||||||
|
obj.batch_encoding_size = batch_encoding_size
|
||||||
|
obj.episodes_since_last_encoding = 0
|
||||||
|
|
||||||
if image_writer_processes or image_writer_threads:
|
if image_writer_processes or image_writer_threads:
|
||||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import shutil
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -451,3 +452,66 @@ def get_image_pixel_channels(image: Image):
|
|||||||
return 4 # RGBA
|
return 4 # RGBA
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown format")
|
raise ValueError("Unknown format")
|
||||||
|
|
||||||
|
|
||||||
|
class VideoEncodingManager:
|
||||||
|
"""
|
||||||
|
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
|
||||||
|
|
||||||
|
This manager handles:
|
||||||
|
- Batch encoding for any remaining episodes when recording interrupted
|
||||||
|
- Cleaning up temporary image files from interrupted episodes
|
||||||
|
- Removing empty image directories
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The LeRobotDataset instance
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset):
|
||||||
|
self.dataset = dataset
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
# Handle any remaining episodes that haven't been batch encoded
|
||||||
|
if self.dataset.episodes_since_last_encoding > 0:
|
||||||
|
if exc_type is not None:
|
||||||
|
logging.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||||
|
else:
|
||||||
|
logging.info("Recording stopped. Encoding remaining episodes...")
|
||||||
|
|
||||||
|
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
|
||||||
|
end_ep = self.dataset.num_episodes
|
||||||
|
logging.info(
|
||||||
|
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
|
||||||
|
f"from episode {start_ep} to {end_ep - 1}"
|
||||||
|
)
|
||||||
|
self.dataset.batch_encode_videos(start_ep, end_ep)
|
||||||
|
|
||||||
|
# Clean up episode images if recording was interrupted
|
||||||
|
if exc_type is not None:
|
||||||
|
interrupted_episode_index = self.dataset.num_episodes
|
||||||
|
for key in self.dataset.meta.video_keys:
|
||||||
|
img_dir = self.dataset._get_image_file_path(
|
||||||
|
episode_index=interrupted_episode_index, image_key=key, frame_index=0
|
||||||
|
).parent
|
||||||
|
if img_dir.exists():
|
||||||
|
logging.debug(
|
||||||
|
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||||
|
)
|
||||||
|
shutil.rmtree(img_dir)
|
||||||
|
|
||||||
|
# Clean up any remaining images directory if it's empty
|
||||||
|
img_dir = self.dataset.root / "images"
|
||||||
|
# Check for any remaining PNG files
|
||||||
|
png_files = list(img_dir.rglob("*.png"))
|
||||||
|
if len(png_files) == 0:
|
||||||
|
# Only remove the images directory if no PNG files remain
|
||||||
|
if img_dir.exists():
|
||||||
|
shutil.rmtree(img_dir)
|
||||||
|
logging.debug("Cleaned up empty images directory")
|
||||||
|
else:
|
||||||
|
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||||
|
|
||||||
|
return False # Don't suppress the original exception
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||||
|
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
@@ -145,6 +146,9 @@ class DatasetRecordConfig:
|
|||||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||||
# Not enough threads might cause low camera fps.
|
# Not enough threads might cause low camera fps.
|
||||||
num_image_writer_threads_per_camera: int = 4
|
num_image_writer_threads_per_camera: int = 4
|
||||||
|
# Number of episodes to record before batch encoding videos
|
||||||
|
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||||
|
video_encoding_batch_size: int = 1
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.single_task is None:
|
if self.single_task is None:
|
||||||
@@ -298,6 +302,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
cfg.dataset.repo_id,
|
cfg.dataset.repo_id,
|
||||||
root=cfg.dataset.root,
|
root=cfg.dataset.root,
|
||||||
|
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||||
@@ -318,6 +323,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
use_videos=cfg.dataset.video,
|
use_videos=cfg.dataset.video,
|
||||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||||
|
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
@@ -329,46 +335,47 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
|
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
|
|
||||||
recorded_episodes = 0
|
with VideoEncodingManager(dataset):
|
||||||
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
recorded_episodes = 0
|
||||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||||
record_loop(
|
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=cfg.dataset.fps,
|
|
||||||
teleop=teleop,
|
|
||||||
policy=policy,
|
|
||||||
dataset=dataset,
|
|
||||||
control_time_s=cfg.dataset.episode_time_s,
|
|
||||||
single_task=cfg.dataset.single_task,
|
|
||||||
display_data=cfg.display_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute a few seconds without recording to give time to manually reset the environment
|
|
||||||
# Skip reset for the last episode to be recorded
|
|
||||||
if not events["stop_recording"] and (
|
|
||||||
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
|
|
||||||
):
|
|
||||||
log_say("Reset the environment", cfg.play_sounds)
|
|
||||||
record_loop(
|
record_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=cfg.dataset.fps,
|
fps=cfg.dataset.fps,
|
||||||
teleop=teleop,
|
teleop=teleop,
|
||||||
control_time_s=cfg.dataset.reset_time_s,
|
policy=policy,
|
||||||
|
dataset=dataset,
|
||||||
|
control_time_s=cfg.dataset.episode_time_s,
|
||||||
single_task=cfg.dataset.single_task,
|
single_task=cfg.dataset.single_task,
|
||||||
display_data=cfg.display_data,
|
display_data=cfg.display_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
# Execute a few seconds without recording to give time to manually reset the environment
|
||||||
log_say("Re-record episode", cfg.play_sounds)
|
# Skip reset for the last episode to be recorded
|
||||||
events["rerecord_episode"] = False
|
if not events["stop_recording"] and (
|
||||||
events["exit_early"] = False
|
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
|
||||||
dataset.clear_episode_buffer()
|
):
|
||||||
continue
|
log_say("Reset the environment", cfg.play_sounds)
|
||||||
|
record_loop(
|
||||||
|
robot=robot,
|
||||||
|
events=events,
|
||||||
|
fps=cfg.dataset.fps,
|
||||||
|
teleop=teleop,
|
||||||
|
control_time_s=cfg.dataset.reset_time_s,
|
||||||
|
single_task=cfg.dataset.single_task,
|
||||||
|
display_data=cfg.display_data,
|
||||||
|
)
|
||||||
|
|
||||||
dataset.save_episode()
|
if events["rerecord_episode"]:
|
||||||
recorded_episodes += 1
|
log_say("Re-record episode", cfg.play_sounds)
|
||||||
|
events["rerecord_episode"] = False
|
||||||
|
events["exit_early"] = False
|
||||||
|
dataset.clear_episode_buffer()
|
||||||
|
continue
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
recorded_episodes += 1
|
||||||
|
|
||||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user