forked from tangger/lerobot
Fix predict_action from record
This commit is contained in:
@@ -24,6 +24,7 @@ from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
@@ -101,7 +102,9 @@ def is_headless():
|
||||
return True
|
||||
|
||||
|
||||
def predict_action(observation, policy, device, use_amp):
|
||||
def predict_action(
|
||||
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
|
||||
):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
|
||||
@@ -178,9 +178,12 @@ def record_loop(
|
||||
|
||||
observation = robot.get_observation()
|
||||
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
|
||||
if policy is not None:
|
||||
action = predict_action(
|
||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
)
|
||||
else:
|
||||
action = teleop.get_action()
|
||||
@@ -190,7 +193,6 @@ def record_loop(
|
||||
sent_action = robot.send_action(action)
|
||||
|
||||
if dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||
frame = {**observation_frame, **action_frame}
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
|
||||
Reference in New Issue
Block a user