All tests passing except test_control_robot.py

This commit is contained in:
Remi Cadene
2024-07-09 22:53:39 +02:00
parent a0432f1608
commit 798373e7bf
14 changed files with 493 additions and 168 deletions

View File

@@ -1,5 +1,5 @@
"""
Example of usage:
Examples of usage:
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
```bash
@@ -49,15 +49,19 @@ python lerobot/scripts/control_robot.py record_dataset \
--run-compute-stats 1
```
- Train on this dataset (TODO(rcadene)):
- Train on this dataset with the ACT policy:
```bash
python lerobot/scripts/train.py
DATA_DIR=data python lerobot/scripts/train.py \
policy=act_koch_real \
env=koch_real \
dataset_repo_id=$USER/koch_pick_place_lego \
hydra.run.dir=outputs/train/act_koch_real
```
- Run the pretrained policy on the robot:
```bash
python lerobot/scripts/control_robot.py run_policy \
-p TODO(rcadene)
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
```
"""
@@ -117,29 +121,37 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
log_items += [f"ep:{episode_index}"]
if frame_index is not None:
log_items += [f"frame:{frame_index}"]
# total step time displayed in milliseconds and its frequency
log_items += [f"dt:{dt_s * 1000:5.2f}={1/ dt_s:3.1f}hz"]
def log_dt(shortname, dt_val_s):
nonlocal log_items
log_items += [f"{shortname}:{dt_val_s * 1000:5.2f}={1/ dt_val_s:3.1f}hz"]
# total step time displayed in milliseconds and its frequency
log_dt("dt", dt_s)
for name in robot.leader_arms:
read_dt_s = robot.logs[f'read_leader_{name}_pos_dt_s']
log_items += [
f"dtRlead{name[0]}:{read_dt_s * 1000:5.2f}={1/ read_dt_s:3.1f}hz",
]
key = f'read_leader_{name}_pos_dt_s'
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
for name in robot.follower_arms:
write_dt_s = robot.logs[f'write_follower_{name}_goal_pos_dt_s']
read_dt_s = robot.logs[f'read_follower_{name}_pos_dt_s']
log_items += [
f"dtRfoll{name[0]}:{write_dt_s * 1000:5.2f}={1/ write_dt_s:3.1f}hz",
f"dtWfoll{name[0]}:{read_dt_s * 1000:5.2f}={1/ read_dt_s:3.1f}hz",
]
key = f'write_follower_{name}_goal_pos_dt_s'
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
key = f'read_follower_{name}_pos_dt_s'
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
for name in robot.cameras:
read_dt_s = robot.logs[f"read_camera_{name}_dt_s"]
async_read_dt_s = robot.logs[f"async_read_camera_{name}_dt_s"]
log_items += [
f"dtRcam{name[0]}:{read_dt_s * 1000:5.2f}={1/read_dt_s:3.1f}hz",
f"dtARcam{name[0]}:{async_read_dt_s * 1000:5.2f}={1/async_read_dt_s:3.1f}hz",
]
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt("dtRcam", robot.logs[key])
key = f"async_read_camera_{name}_dt_s"
if key in robot.logs:
log_dt("dtARcam", robot.logs[key])
logging.info(" ".join(log_items))
########################################################################################
@@ -147,10 +159,12 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
########################################################################################
def teleoperate(robot: Robot, fps: int | None = None):
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
start_time = time.perf_counter()
while True:
now = time.perf_counter()
robot.teleop_step()
@@ -162,6 +176,9 @@ def teleoperate(robot: Robot, fps: int | None = None):
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s)
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
break
def record_dataset(
robot: Robot,
@@ -174,6 +191,8 @@ def record_dataset(
video=True,
run_compute_stats=True,
):
# TODO(rcadene): Add option to record logs
if not video:
raise NotImplementedError()
@@ -327,8 +346,11 @@ def record_dataset(
# TODO(rcadene): push to hub
return lerobot_dataset
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
raise ValueError(local_dir)
@@ -357,7 +379,8 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
log_control_info(robot, dt_s)
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
# TODO(rcadene): Add option to record eval dataset and logs
policy.eval()
# Check device is available
@@ -372,6 +395,7 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
if not robot.is_connected:
robot.connect()
start_time = time.perf_counter()
while True:
now = time.perf_counter()
@@ -391,6 +415,9 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s)
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
break
if __name__ == "__main__":
parser = argparse.ArgumentParser()