Fix LeKiwi example (#1217)

This commit is contained in:
Simon Alibert
2025-06-06 10:08:03 +02:00
committed by GitHub
parent 9e6f49f507
commit 95df341b4f

View File

@@ -1,8 +1,6 @@
import torch
from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.common.utils.control_utils import predict_action
from lerobot.common.utils.utils import get_safe_torch_device
@@ -16,22 +14,18 @@ robot.connect()
policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle")
policy.reset()
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
print("Running inference")
i = 0
while i < NB_CYCLES_CLIENT_CONNECTION:
obs = robot.get_observation()
for key, value in obs.items():
if isinstance(value, torch.Tensor):
obs[key] = value.numpy()
observation_frame = build_dataset_frame(obs_features, obs, prefix="observation")
action_values = predict_action(
obs, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
action = {
key: action_values[i].item() if isinstance(action_values[i], torch.Tensor) else action_values[i]
for i, key in enumerate(robot.action_features)
}
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
robot.send_action(action)
i += 1