Add add_episode & task logic

This commit is contained in:
Simon Alibert
2024-10-21 19:30:20 +02:00
parent 9ebf8b88ec
commit 299451af81
4 changed files with 203 additions and 18 deletions

View File

@@ -109,8 +109,6 @@ from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import (
create_lerobot_dataset,
delete_current_episode,
save_current_episode,
)
from lerobot.common.robot_devices.control_utils import (
control_loop,
@@ -195,6 +193,7 @@ def record(
robot: Robot,
root: str,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None,
@@ -219,6 +218,11 @@ def record(
device = None
use_amp = None
if single_task:
task = single_task
else:
raise NotImplementedError("Only single-task recording is supported for now")
# Load pretrained policy
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -235,8 +239,8 @@ def record(
sanity_check_dataset_name(repo_id, policy)
image_writer = ImageWriter(
write_dir=root,
num_image_writer_processes=num_image_writer_processes,
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
@@ -261,7 +265,12 @@ def record(
if recorded_episodes >= num_episodes:
break
episode_index = dataset["num_episodes"]
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
# input() messes with them.
# if multi_task:
# task = input("Enter your task description: ")
episode_index = dataset.episode_buffer["episode_index"]
log_say(f"Recording episode {episode_index}", play_sounds)
record_episode(
dataset=dataset,
@@ -289,11 +298,11 @@ def record(
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
delete_current_episode(dataset)
dataset.delete_episode()
continue
# Increment by one dataset["current_episode_index"]
save_current_episode(dataset)
dataset.add_episode(task)
if events["stop_recording"]:
break
@@ -378,9 +387,21 @@ if __name__ == "__main__":
)
parser_record = subparsers.add_parser("record", parents=[base_parser])
task_args = parser_record.add_mutually_exclusive_group(required=True)
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the task performed during the recording.",
)
# TODO(aliberts): add multi-task support
# task_args.add_argument(
# "--multi-task",
# type=int,
# help="You will need to enter the task performed at the start of each episode.",
# )
parser_record.add_argument(
"--root",
type=Path,