Add local_files_only, encode_videos, fix bugs to pass tests (WIP)
This commit is contained in:
@@ -194,19 +194,17 @@ def record(
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
fps: int | None = None,
|
||||
warmup_time_s=2,
|
||||
episode_time_s=10,
|
||||
reset_time_s=5,
|
||||
num_episodes=50,
|
||||
video=True,
|
||||
run_compute_stats=True,
|
||||
push_to_hub=True,
|
||||
tags=None,
|
||||
num_image_writer_processes=0,
|
||||
num_image_writer_threads_per_camera=4,
|
||||
force_override=False,
|
||||
display_cameras=True,
|
||||
play_sounds=True,
|
||||
warmup_time_s: int | float = 2,
|
||||
episode_time_s: int | float = 10,
|
||||
reset_time_s: int | float = 5,
|
||||
num_episodes: int = 50,
|
||||
video: bool = True,
|
||||
run_compute_stats: bool = True,
|
||||
push_to_hub: bool = True,
|
||||
num_image_writer_processes: int = 0,
|
||||
num_image_writer_threads_per_camera: int = 4,
|
||||
display_cameras: bool = True,
|
||||
play_sounds: bool = True,
|
||||
) -> LeRobotDataset:
|
||||
# TODO(rcadene): Add option to record logs
|
||||
listener = None
|
||||
@@ -234,12 +232,18 @@ def record(
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
image_writer = ImageWriter(
|
||||
write_dir=root,
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
if len(robot.cameras) > 0:
|
||||
image_writer = ImageWriter(
|
||||
write_dir=root,
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
else:
|
||||
image_writer = None
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
|
||||
)
|
||||
dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
@@ -307,8 +311,9 @@ def record(
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
dataset.image_writer.stop()
|
||||
if dataset.image_writer is not None:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
dataset.image_writer.stop()
|
||||
|
||||
dataset.consolidate(run_compute_stats)
|
||||
|
||||
@@ -322,27 +327,28 @@ def record(
|
||||
|
||||
@safe_disconnect
|
||||
def replay(
|
||||
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
|
||||
robot: Robot,
|
||||
root: Path,
|
||||
repo_id: str,
|
||||
episode: int,
|
||||
fps: int | None = None,
|
||||
play_sounds: bool = True,
|
||||
local_files_only: bool = True,
|
||||
):
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
raise ValueError(local_dir)
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
items = dataset.hf_dataset.select_columns("action")
|
||||
from_idx = dataset.episode_data_index["from"][episode].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode].item()
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", play_sounds, blocking=True)
|
||||
for idx in range(from_idx, to_idx):
|
||||
for idx in range(dataset.num_samples):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = items[idx]["action"]
|
||||
action = actions[idx]["action"]
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
|
||||
Reference in New Issue
Block a user