Fix LeKiwi example (#1217)
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
import torch
|
from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||||
|
|
||||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||||
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
|
|
||||||
from lerobot.common.utils.control_utils import predict_action
|
from lerobot.common.utils.control_utils import predict_action
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device
|
from lerobot.common.utils.utils import get_safe_torch_device
|
||||||
|
|
||||||
@@ -16,22 +14,18 @@ robot.connect()
|
|||||||
policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle")
|
policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle")
|
||||||
policy.reset()
|
policy.reset()
|
||||||
|
|
||||||
|
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||||
|
|
||||||
print("Running inference")
|
print("Running inference")
|
||||||
i = 0
|
i = 0
|
||||||
while i < NB_CYCLES_CLIENT_CONNECTION:
|
while i < NB_CYCLES_CLIENT_CONNECTION:
|
||||||
obs = robot.get_observation()
|
obs = robot.get_observation()
|
||||||
|
|
||||||
for key, value in obs.items():
|
observation_frame = build_dataset_frame(obs_features, obs, prefix="observation")
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
obs[key] = value.numpy()
|
|
||||||
|
|
||||||
action_values = predict_action(
|
action_values = predict_action(
|
||||||
obs, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||||
)
|
)
|
||||||
action = {
|
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
|
||||||
key: action_values[i].item() if isinstance(action_values[i], torch.Tensor) else action_values[i]
|
|
||||||
for i, key in enumerate(robot.action_features)
|
|
||||||
}
|
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user