WIP
@@ -20,8 +20,10 @@ The majority of changes here involve removing unused code, unifying naming, and
 """
 
 import math
-from collections import deque
+import threading
+import time
 from itertools import chain
+from threading import Thread
 from typing import Callable
 
 import einops
@@ -36,6 +38,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.utils import TemporalQueue
 
 
 class ACTPolicy(
@@ -87,22 +90,23 @@ class ACTPolicy(
         self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
 
         self.reset()
+        self.thread = None
+
+        # TODO(rcadene): Add delta timestamps in policy
+        FPS = 10  # noqa: N806
+        self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
 
     def reset(self):
         """This should be called whenever the environment is reset."""
         if self.config.temporal_ensemble_coeff is not None:
             self.temporal_ensembler.reset()
         else:
-            self._action_queue = deque([], maxlen=self.config.n_action_steps)
+            # TODO(rcadene): set proper maxlen
+            self._obs_queue = TemporalQueue(maxlen=1)
+            self._action_queue = TemporalQueue(maxlen=200)
 
     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
+    def inference(self, batch: dict[str, Tensor]) -> Tensor:
         self.eval()
 
         batch = self.normalize_inputs(batch)
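Note: `TemporalQueue` is imported from `lerobot.common.policies.utils`, but its implementation is not part of this diff. Judging only from the call sites in this commit (`add(item, timestamp)`, `get_latest()`, and `get(timestamp)`), a minimal sketch of the interface the commit assumes could look like the following. The method names come from the diff; the lock, the eviction policy, and the nearest-timestamp lookup are guesses.

import threading


class TemporalQueue:
    """Hypothetical sketch, not the actual lerobot implementation."""

    def __init__(self, maxlen: int):
        self.maxlen = maxlen
        self.items: list[tuple[float, object]] = []  # (timestamp, item) pairs, oldest first
        self.lock = threading.Lock()  # shared between the control and inference threads

    def add(self, item, timestamp: float):
        with self.lock:
            self.items.append((timestamp, item))
            if len(self.items) > self.maxlen:
                self.items.pop(0)  # evict the oldest entry

    def get_latest(self):
        # Matches `last_observation, last_timestamp = self._obs_queue.get_latest()`.
        with self.lock:
            if not self.items:
                raise ValueError("queue is empty")
            timestamp, item = self.items[-1]
            return item, timestamp

    def get(self, timestamp: float):
        # Assumed semantics: return the item stamped closest to `timestamp`,
        # raising ValueError when nothing is available yet, which is exactly
        # what the new `select_action` below catches before sleeping.
        with self.lock:
            if not self.items:
                raise ValueError("no item available at this time")
            _, item = min(self.items, key=lambda pair: abs(pair[0] - timestamp))
            return item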
@@ -118,18 +122,47 @@ class ACTPolicy(
             action = self.temporal_ensembler.update(actions)
             return action
 
-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
+        actions = self.model(batch)[0][:, : self.config.n_action_steps]
 
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        # TODO(rcadene): make _forward return output dictionary?
+        actions = self.unnormalize_outputs({"action": actions})["action"]
+        return actions
 
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
+    def inference_loop(self):
+        prev_timestamp = None
+        while not self.stop_event.is_set():
+            last_observation, last_timestamp = self._obs_queue.get_latest()
+
+            if prev_timestamp is not None and prev_timestamp == last_timestamp:
+                # In case inference ran faster than recording/adding a new observation in the queue
+                time.sleep(0.1)
+                continue
+
+            pred_action_sequence = self.inference(last_observation)
+
+            for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
+                self._action_queue.add(action, last_timestamp + delta_ts)
+
+            prev_timestamp = last_timestamp
+
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+        present_time = time.time()
+        self._obs_queue.add(batch, present_time)
+
+        if self.thread is None:
+            self.stop_event = threading.Event()
+            self.thread = Thread(target=self.inference_loop, args=())
+            self.thread.daemon = True
+            self.thread.start()
+
+        next_action = None
+        while next_action is None:
+            try:
+                next_action = self._action_queue.get(present_time)
+            except ValueError:
+                time.sleep(0.1)  # no action available at this present time, we wait a bit
+
+        return next_action
+
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
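For readers following the scheduling logic in the hunk above: each predicted action chunk is stamped relative to the wall-clock time of the observation it was computed from, using `self.delta_timestamps`. A small worked example under the diff's `FPS = 10`; all other numbers here are illustrative.

FPS = 10
n_action_steps = 5  # illustrative; the real value comes from the policy config
delta_timestamps = [i / FPS for i in range(n_action_steps)]
# -> [0.0, 0.1, 0.2, 0.3, 0.4]

last_timestamp = 1000.00  # time.time() when the observation was queued
schedule = [last_timestamp + dt for dt in delta_timestamps]
# -> [1000.0, 1000.1, 1000.2, 1000.3, 1000.4]

# A call to select_action(batch) at present_time = 1000.27 would then ask the
# action queue for the entry stamped nearest to 1000.27 (assumed semantics of
# TemporalQueue.get), i.e. the action scheduled at 1000.3, while the inference
# thread keeps refilling the queue from newer observations in the background.

This is also why the new `select_action` spins on `time.sleep(0.1)`: the very first call arrives before the background thread has produced anything, so the queue raises ValueError until the first chunk has been stamped and added.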