Fix predict_action from record

This commit is contained in:
Simon Alibert
2025-05-24 10:48:06 +02:00
parent 71e6520cd1
commit f4c11593d4
2 changed files with 8 additions and 3 deletions

View File

@@ -24,6 +24,7 @@ from contextlib import nullcontext
from copy import copy from copy import copy
from functools import cache from functools import cache
import numpy as np
import rerun as rr import rerun as rr
import torch import torch
from deepdiff import DeepDiff from deepdiff import DeepDiff
@@ -101,7 +102,9 @@ def is_headless():
return True return True
def predict_action(observation, policy, device, use_amp): def predict_action(
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
):
observation = copy(observation) observation = copy(observation)
with ( with (
torch.inference_mode(), torch.inference_mode(),

View File

@@ -178,9 +178,12 @@ def record_loop(
observation = robot.get_observation() observation = robot.get_observation()
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None: if policy is not None:
action = predict_action( action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
) )
else: else:
action = teleop.get_action() action = teleop.get_action()
@@ -190,7 +193,6 @@ def record_loop(
sent_action = robot.send_action(action) sent_action = robot.send_action(action)
if dataset is not None: if dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame} frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task) dataset.add_frame(frame, task=single_task)