Add local_files_only, encode_videos, fix bugs to pass tests (WIP)

This commit is contained in:
Simon Alibert
2024-10-22 19:57:52 +02:00
parent e991a31061
commit a805458c7e
4 changed files with 183 additions and 80 deletions

View File

@@ -194,19 +194,17 @@ def record(
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None,
warmup_time_s=2,
episode_time_s=10,
reset_time_s=5,
num_episodes=50,
video=True,
run_compute_stats=True,
push_to_hub=True,
tags=None,
num_image_writer_processes=0,
num_image_writer_threads_per_camera=4,
force_override=False,
display_cameras=True,
play_sounds=True,
warmup_time_s: int | float = 2,
episode_time_s: int | float = 10,
reset_time_s: int | float = 5,
num_episodes: int = 50,
video: bool = True,
run_compute_stats: bool = True,
push_to_hub: bool = True,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
listener = None
@@ -234,12 +232,18 @@ def record(
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
image_writer = ImageWriter(
write_dir=root,
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
if len(robot.cameras) > 0:
image_writer = ImageWriter(
write_dir=root,
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
else:
image_writer = None
dataset = LeRobotDataset.create(
repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
)
dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)
if not robot.is_connected:
robot.connect()
@@ -307,8 +311,9 @@ def record(
log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras)
logging.info("Waiting for image writer to terminate...")
dataset.image_writer.stop()
if dataset.image_writer is not None:
logging.info("Waiting for image writer to terminate...")
dataset.image_writer.stop()
dataset.consolidate(run_compute_stats)
@@ -322,27 +327,28 @@ def record(
@safe_disconnect
def replay(
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
robot: Robot,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = True,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
raise ValueError(local_dir)
dataset = LeRobotDataset(repo_id, root=root)
items = dataset.hf_dataset.select_columns("action")
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
log_say("Replaying episode", play_sounds, blocking=True)
for idx in range(from_idx, to_idx):
for idx in range(dataset.num_samples):
start_episode_t = time.perf_counter()
action = items[idx]["action"]
action = actions[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t