From 95df341b4fe7e2bf1ffcd3ea571205fb59a86ff3 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Fri, 6 Jun 2025 10:08:03 +0200 Subject: [PATCH] Fix LeKiwi example (#1217) --- examples/lekiwi/evaluate.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 5acff9b5..2a41440a 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -1,8 +1,6 @@ -import torch - +from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features from lerobot.common.policies.act.modeling_act import ACTPolicy -from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig -from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient +from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.common.utils.control_utils import predict_action from lerobot.common.utils.utils import get_safe_torch_device @@ -16,22 +14,18 @@ robot.connect() policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle") policy.reset() +obs_features = hw_to_dataset_features(robot.observation_features, "observation") + print("Running inference") i = 0 while i < NB_CYCLES_CLIENT_CONNECTION: obs = robot.get_observation() - for key, value in obs.items(): - if isinstance(value, torch.Tensor): - obs[key] = value.numpy() - + observation_frame = build_dataset_frame(obs_features, obs, prefix="observation") action_values = predict_action( - obs, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp ) - action = { - key: action_values[i].item() if isinstance(action_values[i], torch.Tensor) else action_values[i] - for i, key in enumerate(robot.action_features) - } + action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)} robot.send_action(action) i += 1