@@ -29,7 +29,6 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--num-episodes 1 \
|
||||
--run-compute-stats 0
|
||||
@@ -38,7 +37,6 @@ python lerobot/scripts/control_robot.py record \
|
||||
- Visualize dataset:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--episode-index 0
|
||||
```
|
||||
@@ -47,7 +45,6 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--episode 0
|
||||
```
|
||||
@@ -57,7 +54,6 @@ python lerobot/scripts/control_robot.py replay \
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/koch_pick_place_lego \
|
||||
--num-episodes 50 \
|
||||
--warmup-time-s 2 \
|
||||
@@ -77,7 +73,7 @@ To avoid resuming by deleting the dataset, use `--force-override 1`.
|
||||
|
||||
- Train on this dataset with the ACT policy:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
dataset_repo_id=$USER/koch_pick_place_lego \
|
||||
@@ -88,7 +84,6 @@ DATA_DIR=data python lerobot/scripts/train.py \
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/eval_act_koch_real \
|
||||
--num-episodes 10 \
|
||||
--warmup-time-s 2 \
|
||||
@@ -106,12 +101,6 @@ from typing import List
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import (
|
||||
create_lerobot_dataset,
|
||||
delete_current_episode,
|
||||
init_dataset,
|
||||
save_current_episode,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
has_method,
|
||||
@@ -121,6 +110,7 @@ from lerobot.common.robot_devices.control_utils import (
|
||||
record_episode,
|
||||
reset_environment,
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
stop_recording,
|
||||
warmup_record,
|
||||
)
|
||||
@@ -196,25 +186,28 @@ def teleoperate(
|
||||
@safe_disconnect
|
||||
def record(
|
||||
robot: Robot,
|
||||
root: str,
|
||||
root: Path,
|
||||
repo_id: str,
|
||||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
fps: int | None = None,
|
||||
warmup_time_s=2,
|
||||
episode_time_s=10,
|
||||
reset_time_s=5,
|
||||
num_episodes=50,
|
||||
video=True,
|
||||
run_compute_stats=True,
|
||||
push_to_hub=True,
|
||||
tags=None,
|
||||
num_image_writer_processes=0,
|
||||
num_image_writer_threads_per_camera=4,
|
||||
force_override=False,
|
||||
display_cameras=True,
|
||||
play_sounds=True,
|
||||
):
|
||||
warmup_time_s: int | float = 2,
|
||||
episode_time_s: int | float = 10,
|
||||
reset_time_s: int | float = 5,
|
||||
num_episodes: int = 50,
|
||||
video: bool = True,
|
||||
run_compute_stats: bool = True,
|
||||
push_to_hub: bool = True,
|
||||
tags: list[str] | None = None,
|
||||
num_image_writer_processes: int = 0,
|
||||
num_image_writer_threads_per_camera: int = 4,
|
||||
display_cameras: bool = True,
|
||||
play_sounds: bool = True,
|
||||
resume: bool = False,
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
local_files_only: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
# TODO(rcadene): Add option to record logs
|
||||
listener = None
|
||||
events = None
|
||||
@@ -222,6 +215,11 @@ def record(
|
||||
device = None
|
||||
use_amp = None
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
else:
|
||||
raise NotImplementedError("Only single-task recording is supported for now")
|
||||
|
||||
# Load pretrained policy
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
@@ -234,18 +232,29 @@ def record(
|
||||
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
|
||||
)
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
dataset = init_dataset(
|
||||
repo_id,
|
||||
root,
|
||||
force_override,
|
||||
fps,
|
||||
video,
|
||||
write_images=robot.has_camera,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
if resume:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=root,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
dataset.start_image_writer(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id,
|
||||
fps,
|
||||
root=root,
|
||||
robot=robot,
|
||||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
@@ -263,12 +272,17 @@ def record(
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
recorded_episodes = 0
|
||||
while True:
|
||||
if dataset["num_episodes"] >= num_episodes:
|
||||
if recorded_episodes >= num_episodes:
|
||||
break
|
||||
|
||||
episode_index = dataset["num_episodes"]
|
||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
|
||||
# input() messes with them.
|
||||
# if multi_task:
|
||||
# task = input("Enter your task description: ")
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
robot=robot,
|
||||
@@ -286,7 +300,7 @@ def record(
|
||||
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(episode_index < num_episodes - 1) or events["rerecord_episode"]
|
||||
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
@@ -295,11 +309,11 @@ def record(
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
delete_current_episode(dataset)
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Increment by one dataset["current_episode_index"]
|
||||
save_current_episode(dataset)
|
||||
dataset.save_episode(task)
|
||||
recorded_episodes += 1
|
||||
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
@@ -307,35 +321,42 @@ def record(
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
|
||||
dataset.consolidate(run_compute_stats)
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub(tags=tags)
|
||||
|
||||
log_say("Exiting", play_sounds)
|
||||
return lerobot_dataset
|
||||
return dataset
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
def replay(
|
||||
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
|
||||
robot: Robot,
|
||||
root: Path,
|
||||
repo_id: str,
|
||||
episode: int,
|
||||
fps: int | None = None,
|
||||
play_sounds: bool = True,
|
||||
local_files_only: bool = True,
|
||||
):
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
raise ValueError(local_dir)
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
items = dataset.hf_dataset.select_columns("action")
|
||||
from_idx = dataset.episode_data_index["from"][episode].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode].item()
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", play_sounds, blocking=True)
|
||||
for idx in range(from_idx, to_idx):
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = items[idx]["action"]
|
||||
action = actions[idx]["action"]
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
@@ -384,13 +405,25 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
task_args = parser_record.add_mutually_exclusive_group(required=True)
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
type=str,
|
||||
help="A short but accurate description of the task performed during the recording.",
|
||||
)
|
||||
# TODO(aliberts): add multi-task support
|
||||
# task_args.add_argument(
|
||||
# "--multi-task",
|
||||
# type=int,
|
||||
# help="You will need to enter the task performed at the start of each episode.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default="data",
|
||||
default=None,
|
||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
@@ -458,10 +491,10 @@ if __name__ == "__main__":
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--force-override",
|
||||
"--resume",
|
||||
type=int,
|
||||
default=0,
|
||||
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
|
||||
help="Resume recording on an existing dataset.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"-p",
|
||||
@@ -486,7 +519,7 @@ if __name__ == "__main__":
|
||||
parser_replay.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default="data",
|
||||
default=None,
|
||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user