forked from tangger/lerobot
Style
This commit is contained in:
@@ -62,20 +62,22 @@ python lerobot/scripts/control_robot.py run_policy \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
import concurrent.futures
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from omegaconf import DictConfig
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, load_hf_dataset
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
@@ -83,14 +85,12 @@ from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
from lerobot.scripts.push_dataset_to_hub import save_meta_data
|
||||
from lerobot.scripts.robot_controls.record_dataset import record_dataset
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
@@ -106,15 +106,18 @@ def busy_wait(seconds):
|
||||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
|
||||
|
||||
def none_or_int(value):
|
||||
if value == 'None':
|
||||
if value == "None":
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleoperate(robot: Robot, fps: int | None = None):
|
||||
robot.init_teleop()
|
||||
|
||||
@@ -123,14 +126,24 @@ def teleoperate(robot: Robot, fps: int | None = None):
|
||||
robot.teleop_step()
|
||||
|
||||
if fps is not None:
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, num_episodes=50, video=True, run_compute_stats=True):
|
||||
def record_dataset(
|
||||
robot: Robot,
|
||||
fps: int | None = None,
|
||||
root="data",
|
||||
repo_id="lerobot/debug",
|
||||
warmup_time_s=2,
|
||||
episode_time_s=10,
|
||||
num_episodes=50,
|
||||
video=True,
|
||||
run_compute_stats=True,
|
||||
):
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -143,7 +156,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||
videos_dir = local_dir / "videos"
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
is_warmup_print = False
|
||||
@@ -154,7 +166,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||
# Using `with` ensures the program exists smoothly if an execption is raised.
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
for episode_index in range(num_episodes):
|
||||
|
||||
ep_dict = {}
|
||||
frame_index = 0
|
||||
|
||||
@@ -169,10 +180,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||
timestamp = time.perf_counter() - start_time
|
||||
|
||||
if timestamp < warmup_time_s:
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
|
||||
continue
|
||||
|
||||
@@ -199,10 +210,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||
|
||||
frame_index += 1
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
if timestamp > episode_time_s - warmup_time_s:
|
||||
@@ -229,7 +240,7 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||
|
||||
for key in not_image_keys:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
|
||||
for key in action:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
@@ -269,10 +280,7 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||
info=info,
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
if run_compute_stats:
|
||||
stats = compute_stats(lerobot_dataset)
|
||||
else:
|
||||
stats = {}
|
||||
stats = compute_stats(lerobot_dataset) if run_compute_stats else {}
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
@@ -293,7 +301,7 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||
to_idx = dataset.episode_data_index["to"][episode].item()
|
||||
|
||||
robot.init_teleop()
|
||||
|
||||
|
||||
print("Replaying episode")
|
||||
os.system('say "Replaying episode"')
|
||||
|
||||
@@ -303,10 +311,10 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||
action = items[idx]["action"]
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
@@ -327,15 +335,18 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
||||
|
||||
observation = robot.capture_observation()
|
||||
|
||||
with torch.inference_mode(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
||||
):
|
||||
action = policy.select_action(observation)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
@@ -345,32 +356,46 @@ if __name__ == "__main__":
|
||||
|
||||
# Set common options for all the subparsers
|
||||
base_parser = argparse.ArgumentParser(add_help=False)
|
||||
base_parser.add_argument("--robot", type=str, default="koch", help="Name of the robot provided to the `make_robot(name)` factory function.")
|
||||
base_parser.add_argument(
|
||||
"--robot",
|
||||
type=str,
|
||||
default="koch",
|
||||
help="Name of the robot provided to the `make_robot(name)` factory function.",
|
||||
)
|
||||
|
||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||
parser_teleop.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
parser_teleop.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
|
||||
parser_record.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
parser_record.add_argument('--root', type=Path, default="data", help='')
|
||||
parser_record.add_argument('--repo-id', type=str, default="lerobot/test", help='')
|
||||
parser_record.add_argument('--warmup-time-s', type=int, default=2, help='')
|
||||
parser_record.add_argument('--episode-time-s', type=int, default=10, help='')
|
||||
parser_record.add_argument('--num-episodes', type=int, default=50, help='')
|
||||
parser_record.add_argument('--run-compute-stats', type=int, default=1, help='')
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_record.add_argument("--root", type=Path, default="data", help="")
|
||||
parser_record.add_argument("--repo-id", type=str, default="lerobot/test", help="")
|
||||
parser_record.add_argument("--warmup-time-s", type=int, default=2, help="")
|
||||
parser_record.add_argument("--episode-time-s", type=int, default=10, help="")
|
||||
parser_record.add_argument("--num-episodes", type=int, default=50, help="")
|
||||
parser_record.add_argument("--run-compute-stats", type=int, default=1, help="")
|
||||
|
||||
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
|
||||
parser_replay.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
parser_replay.add_argument('--root', type=Path, default="data", help='')
|
||||
parser_replay.add_argument('--repo-id', type=str, default="lerobot/test", help='')
|
||||
parser_replay.add_argument('--episode', type=int, default=0, help='')
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_replay.add_argument("--root", type=Path, default="data", help="")
|
||||
parser_replay.add_argument("--repo-id", type=str, default="lerobot/test", help="")
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="")
|
||||
|
||||
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
|
||||
parser_policy.add_argument('-p', '--pretrained-policy-name-or-path', type=str,
|
||||
parser_policy.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
)
|
||||
),
|
||||
)
|
||||
parser_policy.add_argument(
|
||||
"overrides",
|
||||
|
||||
@@ -580,9 +580,7 @@ def main(
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(
|
||||
snapshot_download(pretrained_policy_name_or_path, revision=revision)
|
||||
)
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
@@ -644,7 +642,9 @@ if __name__ == "__main__":
|
||||
if args.pretrained_policy_name_or_path is None:
|
||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||
else:
|
||||
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path, revision=args.revision)
|
||||
pretrained_policy_path = get_pretrained_policy_path(
|
||||
args.pretrained_policy_name_or_path, revision=args.revision
|
||||
)
|
||||
|
||||
main(
|
||||
pretrained_policy_path=pretrained_policy_path,
|
||||
|
||||
Reference in New Issue
Block a user