Fix LeKiwi example (#1217)

This commit is contained in:
Simon Alibert
2025-06-06 10:08:03 +02:00
committed by GitHub
parent 9e6f49f507
commit 95df341b4f

View File

@@ -1,8 +1,6 @@
import torch from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.common.policies.act.modeling_act import ACTPolicy from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.common.utils.control_utils import predict_action from lerobot.common.utils.control_utils import predict_action
from lerobot.common.utils.utils import get_safe_torch_device from lerobot.common.utils.utils import get_safe_torch_device
@@ -16,22 +14,18 @@ robot.connect()
policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle") policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle")
policy.reset() policy.reset()
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
print("Running inference") print("Running inference")
i = 0 i = 0
while i < NB_CYCLES_CLIENT_CONNECTION: while i < NB_CYCLES_CLIENT_CONNECTION:
obs = robot.get_observation() obs = robot.get_observation()
for key, value in obs.items(): observation_frame = build_dataset_frame(obs_features, obs, prefix="observation")
if isinstance(value, torch.Tensor):
obs[key] = value.numpy()
action_values = predict_action( action_values = predict_action(
obs, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
) )
action = { action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
key: action_values[i].item() if isinstance(action_values[i], torch.Tensor) else action_values[i]
for i, key in enumerate(robot.action_features)
}
robot.send_action(action) robot.send_action(action)
i += 1 i += 1