70 lines
2.3 KiB
Python
70 lines
2.3 KiB
Python
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
|
from lerobot.common.robot_devices.utils import busy_wait
|
|
import time
|
|
import argparse
|
|
from agilex_robot import AgilexRobot
|
|
import torch
|
|
|
|
def get_arguments(argv=None):
    """Build the recording/control configuration namespace.

    Every setting is declared as a command-line option whose default is the
    value this script used to hard-code. Calling ``get_arguments()`` with no
    command-line arguments therefore yields the same namespace as before,
    while individual values can now be overridden (e.g. ``--fps 15 --no-video``).
    Previously any command-line argument made ``parse_args`` exit with an
    "unrecognized arguments" error.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            parse ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes fps, resume, repo_id, root,
        num_image_writer_processes, num_image_writer_threads_per_camera,
        video, num_episodes, episode_time_s, play_sounds, display_cameras,
        single_task, use_depth_image, use_base, push_to_hub, teleoprate and
        policy. ``policy`` is always ``None`` and has no command-line option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--repo-id", default="tangger/test")
    parser.add_argument("--root", default="./data2")
    parser.add_argument("--num-image-writer-processes", type=int, default=0)
    parser.add_argument("--num-image-writer-threads-per-camera", type=int, default=4)
    # Flags whose default is True are exposed as "--no-..." switches.
    parser.add_argument("--no-video", dest="video", action="store_false")
    parser.add_argument("--num-episodes", type=int, default=50)
    parser.add_argument("--episode-time-s", type=int, default=30000)
    parser.add_argument("--play-sounds", action="store_true")
    parser.add_argument("--no-display-cameras", dest="display_cameras", action="store_false")
    parser.add_argument("--single-task", default="test test")
    parser.add_argument("--use-depth-image", action="store_true")
    parser.add_argument("--use-base", action="store_true")
    parser.add_argument("--push-to-hub", action="store_true")
    # The attribute keeps its original (misspelled) name "teleoprate" because
    # the namespace is handed to the robot driver, which presumably reads it
    # under that name -- verify before renaming.
    parser.add_argument("--teleoperate", dest="teleoprate", action="store_true")
    # Not a CLI option: always present and always None.
    parser.set_defaults(policy=None)
    return parser.parse_args(argv)
|
|
|
|
|
|
cfg = get_arguments()

# Robot driver, configured from the same argument namespace.
robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)

# Total length of the control loop, in seconds.
inference_time_s = 360
# Control rate in Hz. Read from cfg (the namespace also given to the robot)
# instead of repeating the literal, so the two cannot drift apart.
fps = cfg.fps


def _select_device():
    """Return the best available torch device name: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps does not exist on older torch builds.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = _select_device()

ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model"

policy = ACTPolicy.from_pretrained(ckpt_path)
policy.to(device)
# Force evaluation mode (disables dropout and other train-only behaviour),
# regardless of which mode from_pretrained leaves the module in.
policy.eval()
|
|
|
|
def _prepare_observation(observation, device):
    """Convert a raw robot observation into a batched policy input.

    Entries whose key contains "image" are cast to float32, divided by 255
    and permuted from channel-last to channel-first. Every entry (images and
    non-image entries alike) then gets a leading batch dimension of size 1
    and is moved to ``device``.

    Args:
        observation: Mapping of name -> tensor as returned by
            ``robot.capture_observation()``. Image tensors are assumed to be
            (H, W, C) with 0-255 pixel values -- TODO confirm against AgilexRobot.
        device: Target torch device for every tensor.

    Returns:
        A new dict with the same keys, ready for ``policy.select_action``.
    """
    batch = {}
    for name, value in observation.items():
        if "image" in name:
            value = value.type(torch.float32) / 255
            value = value.permute(2, 0, 1).contiguous()
        # Batch dimension and device transfer apply to every entry (state
        # included), not only to images.
        batch[name] = value.unsqueeze(0).to(device)
    return batch


for _ in range(inference_time_s * fps):
    start_time = time.perf_counter()

    # Read the follower state and access the frames from the cameras
    observation = robot.capture_observation()

    if observation is None:
        # Fall through to the rate limiter instead of `continue`: skipping
        # busy_wait would spin unpaced and burn through the fixed iteration
        # budget, ending the run far earlier than inference_time_s.
        print("Observation is None, skipping...")
    else:
        batch = _prepare_observation(observation, device)
        # Compute the next action from the current observation, drop the
        # batch dimension and move it to the CPU for the robot driver.
        action = policy.select_action(batch).squeeze(0).to("cpu")
        robot.send_action(action)

    # Wait out the remainder of this control period to hold `fps`.
    dt_s = time.perf_counter() - start_time
    busy_wait(1 / fps - dt_s)