Fix LeKiwi example (#1217)
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.common.utils.control_utils import predict_action
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
|
||||
@@ -16,22 +14,18 @@ robot.connect()
|
||||
policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle")
|
||||
policy.reset()
|
||||
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
|
||||
print("Running inference")
|
||||
i = 0
|
||||
while i < NB_CYCLES_CLIENT_CONNECTION:
|
||||
obs = robot.get_observation()
|
||||
|
||||
for key, value in obs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
obs[key] = value.numpy()
|
||||
|
||||
observation_frame = build_dataset_frame(obs_features, obs, prefix="observation")
|
||||
action_values = predict_action(
|
||||
obs, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
)
|
||||
action = {
|
||||
key: action_values[i].item() if isinstance(action_values[i], torch.Tensor) else action_values[i]
|
||||
for i, key in enumerate(robot.action_features)
|
||||
}
|
||||
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
|
||||
robot.send_action(action)
|
||||
i += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user