Fix predict_action from record

This commit is contained in:
Simon Alibert
2025-05-24 10:48:06 +02:00
parent 71e6520cd1
commit f4c11593d4
2 changed files with 8 additions and 3 deletions

View File

@@ -24,6 +24,7 @@ from contextlib import nullcontext
from copy import copy
from functools import cache
import numpy as np
import rerun as rr
import torch
from deepdiff import DeepDiff
@@ -101,7 +102,9 @@ def is_headless():
return True
def predict_action(observation, policy, device, use_amp):
def predict_action(
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
):
observation = copy(observation)
with (
torch.inference_mode(),