Fix predict_action from record

This commit is contained in:
Simon Alibert
2025-05-24 10:48:06 +02:00
parent 71e6520cd1
commit f4c11593d4
2 changed files with 8 additions and 3 deletions

View File

@@ -24,6 +24,7 @@ from contextlib import nullcontext
from copy import copy
from functools import cache
import numpy as np
import rerun as rr
import torch
from deepdiff import DeepDiff
@@ -101,7 +102,9 @@ def is_headless():
return True
def predict_action(observation, policy, device, use_amp):
def predict_action(
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
):
observation = copy(observation)
with (
torch.inference_mode(),

View File

@@ -178,9 +178,12 @@ def record_loop(
observation = robot.get_observation()
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None:
action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
else:
action = teleop.get_action()
@@ -190,7 +193,6 @@ def record_loop(
sent_action = robot.send_action(action)
if dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task)