Add add_episode & task logic
This commit is contained in:
@@ -109,8 +109,6 @@ from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import (
|
||||
create_lerobot_dataset,
|
||||
delete_current_episode,
|
||||
save_current_episode,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
@@ -195,6 +193,7 @@ def record(
|
||||
robot: Robot,
|
||||
root: str,
|
||||
repo_id: str,
|
||||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
fps: int | None = None,
|
||||
@@ -219,6 +218,11 @@ def record(
|
||||
device = None
|
||||
use_amp = None
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
else:
|
||||
raise NotImplementedError("Only single-task recording is supported for now")
|
||||
|
||||
# Load pretrained policy
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
@@ -235,8 +239,8 @@ def record(
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
image_writer = ImageWriter(
|
||||
write_dir=root,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
|
||||
|
||||
@@ -261,7 +265,12 @@ def record(
|
||||
if recorded_episodes >= num_episodes:
|
||||
break
|
||||
|
||||
episode_index = dataset["num_episodes"]
|
||||
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
|
||||
# input() messes with them.
|
||||
# if multi_task:
|
||||
# task = input("Enter your task description: ")
|
||||
|
||||
episode_index = dataset.episode_buffer["episode_index"]
|
||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
@@ -289,11 +298,11 @@ def record(
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
delete_current_episode(dataset)
|
||||
dataset.delete_episode()
|
||||
continue
|
||||
|
||||
# Increment by one dataset["current_episode_index"]
|
||||
save_current_episode(dataset)
|
||||
dataset.add_episode(task)
|
||||
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
@@ -378,9 +387,21 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
task_args = parser_record.add_mutually_exclusive_group(required=True)
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
type=str,
|
||||
help="A short but accurate description of the task performed during the recording.",
|
||||
)
|
||||
# TODO(aliberts): add multi-task support
|
||||
# task_args.add_argument(
|
||||
# "--multi-task",
|
||||
# type=int,
|
||||
# help="You will need to enter the task performed at the start of each episode.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
|
||||
Reference in New Issue
Block a user