diff --git a/lerobot/common/utils/control_utils.py b/lerobot/common/utils/control_utils.py index d4790a57..012f9f27 100644 --- a/lerobot/common/utils/control_utils.py +++ b/lerobot/common/utils/control_utils.py @@ -24,6 +24,7 @@ from contextlib import nullcontext from copy import copy from functools import cache +import numpy as np import rerun as rr import torch from deepdiff import DeepDiff @@ -101,7 +102,9 @@ def is_headless(): return True -def predict_action(observation, policy, device, use_amp): +def predict_action( + observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool +): observation = copy(observation) with ( torch.inference_mode(), diff --git a/lerobot/record.py b/lerobot/record.py index 9d07f88b..6ddeb23b 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -178,9 +178,12 @@ def record_loop( observation = robot.get_observation() + if policy is not None or dataset is not None: + observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") + if policy is not None: action = predict_action( - observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp ) else: action = teleop.get_action() @@ -190,7 +193,6 @@ def record_loop( sent_action = robot.send_action(action) if dataset is not None: - observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") frame = {**observation_frame, **action_frame} dataset.add_frame(frame, task=single_task)