forked from tangger/lerobot
All tests passing except test_control_robot.py
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Example of usage:
|
||||
Examples of usage:
|
||||
|
||||
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
||||
```bash
|
||||
@@ -49,15 +49,19 @@ python lerobot/scripts/control_robot.py record_dataset \
|
||||
--run-compute-stats 1
|
||||
```
|
||||
|
||||
- Train on this dataset (TODO(rcadene)):
|
||||
- Train on this dataset with the ACT policy:
|
||||
```bash
|
||||
python lerobot/scripts/train.py
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
dataset_repo_id=$USER/koch_pick_place_lego \
|
||||
hydra.run.dir=outputs/train/act_koch_real
|
||||
```
|
||||
|
||||
- Run the pretrained policy on the robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py run_policy \
|
||||
-p TODO(rcadene)
|
||||
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -117,29 +121,37 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
|
||||
log_items += [f"ep:{episode_index}"]
|
||||
if frame_index is not None:
|
||||
log_items += [f"frame:{frame_index}"]
|
||||
|
||||
# total step time displayed in milliseconds and its frequency
|
||||
log_items += [f"dt:{dt_s * 1000:5.2f}={1/ dt_s:3.1f}hz"]
|
||||
|
||||
def log_dt(shortname, dt_val_s):
|
||||
nonlocal log_items
|
||||
log_items += [f"{shortname}:{dt_val_s * 1000:5.2f}={1/ dt_val_s:3.1f}hz"]
|
||||
|
||||
# total step time displayed in milliseconds and its frequency
|
||||
log_dt("dt", dt_s)
|
||||
|
||||
for name in robot.leader_arms:
|
||||
read_dt_s = robot.logs[f'read_leader_{name}_pos_dt_s']
|
||||
log_items += [
|
||||
f"dtRlead{name[0]}:{read_dt_s * 1000:5.2f}={1/ read_dt_s:3.1f}hz",
|
||||
]
|
||||
key = f'read_leader_{name}_pos_dt_s'
|
||||
if key in robot.logs:
|
||||
log_dt("dtRlead", robot.logs[key])
|
||||
|
||||
for name in robot.follower_arms:
|
||||
write_dt_s = robot.logs[f'write_follower_{name}_goal_pos_dt_s']
|
||||
read_dt_s = robot.logs[f'read_follower_{name}_pos_dt_s']
|
||||
log_items += [
|
||||
f"dtRfoll{name[0]}:{write_dt_s * 1000:5.2f}={1/ write_dt_s:3.1f}hz",
|
||||
f"dtWfoll{name[0]}:{read_dt_s * 1000:5.2f}={1/ read_dt_s:3.1f}hz",
|
||||
]
|
||||
key = f'write_follower_{name}_goal_pos_dt_s'
|
||||
if key in robot.logs:
|
||||
log_dt("dtRfoll", robot.logs[key])
|
||||
|
||||
key = f'read_follower_{name}_pos_dt_s'
|
||||
if key in robot.logs:
|
||||
log_dt("dtWfoll", robot.logs[key])
|
||||
|
||||
for name in robot.cameras:
|
||||
read_dt_s = robot.logs[f"read_camera_{name}_dt_s"]
|
||||
async_read_dt_s = robot.logs[f"async_read_camera_{name}_dt_s"]
|
||||
log_items += [
|
||||
f"dtRcam{name[0]}:{read_dt_s * 1000:5.2f}={1/read_dt_s:3.1f}hz",
|
||||
f"dtARcam{name[0]}:{async_read_dt_s * 1000:5.2f}={1/async_read_dt_s:3.1f}hz",
|
||||
]
|
||||
key = f"read_camera_{name}_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtRcam", robot.logs[key])
|
||||
|
||||
key = f"async_read_camera_{name}_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtARcam", robot.logs[key])
|
||||
|
||||
logging.info(" ".join(log_items))
|
||||
|
||||
########################################################################################
|
||||
@@ -147,10 +159,12 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleoperate(robot: Robot, fps: int | None = None):
|
||||
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
while True:
|
||||
now = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
@@ -162,6 +176,9 @@ def teleoperate(robot: Robot, fps: int | None = None):
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s)
|
||||
|
||||
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
|
||||
break
|
||||
|
||||
|
||||
def record_dataset(
|
||||
robot: Robot,
|
||||
@@ -174,6 +191,8 @@ def record_dataset(
|
||||
video=True,
|
||||
run_compute_stats=True,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -327,8 +346,11 @@ def record_dataset(
|
||||
|
||||
# TODO(rcadene): push to hub
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
raise ValueError(local_dir)
|
||||
@@ -357,7 +379,8 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||
log_control_info(robot, dt_s)
|
||||
|
||||
|
||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record eval dataset and logs
|
||||
policy.eval()
|
||||
|
||||
# Check device is available
|
||||
@@ -372,6 +395,7 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
while True:
|
||||
now = time.perf_counter()
|
||||
|
||||
@@ -391,6 +415,9 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s)
|
||||
|
||||
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user