Add reset-time-s, Add keyboard early exit, Add comments

This commit is contained in:
Remi Cadene
2024-07-12 12:58:56 +02:00
parent 1993d29296
commit 7a659dbd6b
3 changed files with 244 additions and 18 deletions

View File

@@ -39,16 +39,24 @@ python lerobot/scripts/control_robot.py replay_episode \
--episode 0
```
- Record a full dataset in order to train a policy:
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
```bash
python lerobot/scripts/control_robot.py record_dataset \
--fps 30 \
--root data \
--repo-id $USER/koch_pick_place_lego \
--num-episodes 50 \
--run-compute-stats 1
--run-compute-stats 1 \
--warmup-time-s 2 \
--episode-time-s 30 \
--reset-time-s 10
```
**NOTE**: You can early exit while recording an episode or resetting the environment,
by tapping the right arrow key '->'. This might require a sudo permission
to allow your terminal to monitor keyboard events.
- Train on this dataset with the ACT policy:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
@@ -77,6 +85,7 @@ from pathlib import Path
import torch
from omegaconf import DictConfig
from PIL import Image
from pynput import keyboard
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -187,6 +196,7 @@ def record_dataset(
repo_id="lerobot/debug",
warmup_time_s=2,
episode_time_s=10,
reset_time_s=5,
num_episodes=50,
video=True,
run_compute_stats=True,
@@ -228,6 +238,20 @@ def record_dataset(
timestamp = time.perf_counter() - start_time
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
exit_early = False
def on_press(key):
nonlocal exit_early
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
exit_early = True
listener = keyboard.Listener(on_press=on_press)
listener.start()
# Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised.
# Using only 4 worker threads to avoid blocking the main thread.
@@ -235,17 +259,13 @@ def record_dataset(
# Start recording all episodes
ep_dicts = []
for episode_index in range(num_episodes):
logging.info(f"Recording episode {episode_index}")
os.system(f'say "Recording episode {episode_index}" &')
ep_dict = {}
frame_index = 0
timestamp = 0
start_time = time.perf_counter()
is_record_print = False
while timestamp < episode_time_s:
if not is_record_print:
logging.info(f"Recording episode {episode_index}")
os.system(f'say "Recording episode {episode_index}" &')
is_record_print = True
now = time.perf_counter()
observation, action = robot.teleop_step(record_data=True)
@@ -275,6 +295,26 @@ def record_dataset(
timestamp = time.perf_counter() - start_time
if exit_early:
exit_early = False
break
# Skip resetting if 0 second allocated or it is the last episode
if reset_time_s == 0 or episode_index == num_episodes - 1:
continue
logging.info("Resetting environment")
os.system('say "Resetting environment" &')
timestamp = 0
start_time = time.perf_counter()
while timestamp < reset_time_s:
time.sleep(1)
timestamp = time.perf_counter() - start_time
if exit_early:
exit_early = False
break
num_frames = frame_index
for key in image_keys:
@@ -454,20 +494,61 @@ if __name__ == "__main__":
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record.add_argument("--root", type=Path, default="data", help="")
parser_record.add_argument("--repo-id", type=str, default="lerobot/test", help="")
parser_record.add_argument("--warmup-time-s", type=int, default=2, help="")
parser_record.add_argument("--episode-time-s", type=int, default=10, help="")
parser_record.add_argument("--num-episodes", type=int, default=50, help="")
parser_record.add_argument("--run-compute-stats", type=int, default=1, help="")
parser_record.add_argument(
"--root",
type=Path,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_record.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_record.add_argument(
"--warmup-time-s",
type=int,
default=2,
help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
)
parser_record.add_argument(
"--episode-time-s",
type=int,
default=10,
help="Number of seconds for data recording for each episode.",
)
parser_record.add_argument(
"--reset-time-s",
type=int,
default=5,
help="Number of seconds for resetting the environment after each episode.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--run-compute-stats",
type=int,
default=1,
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
)
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_replay.add_argument("--root", type=Path, default="data", help="")
parser_replay.add_argument("--repo-id", type=str, default="lerobot/test", help="")
parser_replay.add_argument("--episode", type=int, default=0, help="")
parser_replay.add_argument(
"--root",
type=Path,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_replay.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
parser_policy.add_argument(