From f4c11593d4ea79070b28a77731e726167f156280 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sat, 24 May 2025 10:48:06 +0200 Subject: [PATCH] Fix predict_action from record --- lerobot/common/utils/control_utils.py | 5 ++++- lerobot/record.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/lerobot/common/utils/control_utils.py b/lerobot/common/utils/control_utils.py index d4790a57..012f9f27 100644 --- a/lerobot/common/utils/control_utils.py +++ b/lerobot/common/utils/control_utils.py @@ -24,6 +24,7 @@ from contextlib import nullcontext from copy import copy from functools import cache +import numpy as np import rerun as rr import torch from deepdiff import DeepDiff @@ -101,7 +102,9 @@ def is_headless(): return True -def predict_action(observation, policy, device, use_amp): +def predict_action( + observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool +): observation = copy(observation) with ( torch.inference_mode(), diff --git a/lerobot/record.py b/lerobot/record.py index 9d07f88b..6ddeb23b 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -178,9 +178,12 @@ def record_loop( observation = robot.get_observation() + if policy is not None or dataset is not None: + observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") + if policy is not None: action = predict_action( - observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp ) else: action = teleop.get_action() @@ -190,7 +193,6 @@ def record_loop( sent_action = robot.send_action(action) if dataset is not None: - observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") frame = {**observation_frame, **action_frame} dataset.add_frame(frame, task=single_task)