Fix predict_action from record
This commit is contained in:
@@ -24,6 +24,7 @@ from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
@@ -101,7 +102,9 @@ def is_headless():
|
||||
return True
|
||||
|
||||
|
||||
def predict_action(observation, policy, device, use_amp):
|
||||
def predict_action(
|
||||
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
|
||||
):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
|
||||
Reference in New Issue
Block a user