Fix predict_action from record
This commit is contained in:
@@ -24,6 +24,7 @@ from contextlib import nullcontext
|
|||||||
from copy import copy
|
from copy import copy
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import rerun as rr
|
import rerun as rr
|
||||||
import torch
|
import torch
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
@@ -101,7 +102,9 @@ def is_headless():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def predict_action(observation, policy, device, use_amp):
|
def predict_action(
|
||||||
|
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
|
||||||
|
):
|
||||||
observation = copy(observation)
|
observation = copy(observation)
|
||||||
with (
|
with (
|
||||||
torch.inference_mode(),
|
torch.inference_mode(),
|
||||||
|
|||||||
@@ -178,9 +178,12 @@ def record_loop(
|
|||||||
|
|
||||||
observation = robot.get_observation()
|
observation = robot.get_observation()
|
||||||
|
|
||||||
|
if policy is not None or dataset is not None:
|
||||||
|
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||||
|
|
||||||
if policy is not None:
|
if policy is not None:
|
||||||
action = predict_action(
|
action = predict_action(
|
||||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
action = teleop.get_action()
|
action = teleop.get_action()
|
||||||
@@ -190,7 +193,6 @@ def record_loop(
|
|||||||
sent_action = robot.send_action(action)
|
sent_action = robot.send_action(action)
|
||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
|
||||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||||
frame = {**observation_frame, **action_frame}
|
frame = {**observation_frame, **action_frame}
|
||||||
dataset.add_frame(frame, task=single_task)
|
dataset.add_frame(frame, task=single_task)
|
||||||
|
|||||||
Reference in New Issue
Block a user