This commit is contained in:
Remi Cadene
2024-09-30 15:22:13 +02:00
parent b2e5f7fe2d
commit ec1efc64b4
3 changed files with 129 additions and 19 deletions

View File

@@ -20,8 +20,10 @@ The majority of changes here involve removing unused code, unifying naming, and
"""
import math
from collections import deque
import threading
import time
from itertools import chain
from threading import Thread
from typing import Callable
import einops
@@ -36,6 +38,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import TemporalQueue
class ACTPolicy(
@@ -87,22 +90,23 @@ class ACTPolicy(
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
self.thread = None
# TODO(rcadene): Add delta timestamps in policy
FPS = 10 # noqa: N806
self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
def reset(self):
"""This should be called whenever the environment is reset."""
# Clears per-episode state: the temporal ensembler when
# `config.temporal_ensemble_coeff` is set, otherwise the queues below.
if self.config.temporal_ensemble_coeff is not None:
self.temporal_ensembler.reset()
else:
# NOTE(review): `_action_queue` is assigned here and again further down;
# if both statements run, the later `TemporalQueue` replaces this deque.
# This looks like removed/added diff lines shown together — confirm which one is live.
self._action_queue = deque([], maxlen=self.config.n_action_steps)
# TODO(rcadene): set proper maxlen
# NOTE(review): `select_action` and `inference_loop` use `_obs_queue` and
# `_action_queue` unconditionally — confirm these two queues are also
# created when temporal ensembling is enabled.
self._obs_queue = TemporalQueue(maxlen=1)
self._action_queue = TemporalQueue(maxlen=200)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
def inference(self, batch: dict[str, Tensor]) -> Tensor:
self.eval()
batch = self.normalize_inputs(batch)
@@ -118,18 +122,47 @@ class ACTPolicy(
action = self.temporal_ensembler.update(actions)
return action
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.model(batch)[0][:, : self.config.n_action_steps]
actions = self.model(batch)[0][:, : self.config.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def inference_loop(self):
prev_timestamp = None
while not self.stop_event.is_set():
last_observation, last_timestamp = self._obs_queue.get_latest()
if prev_timestamp is not None and prev_timestamp == last_timestamp:
# in case inference ran faster than recording/adding a new observation in the queue
time.sleep(0.1)
continue
pred_action_sequence = self.inference(last_observation)
for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
self._action_queue.add(action, last_timestamp + delta_ts)
prev_timestamp = last_timestamp
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
    """Return the action scheduled for the current time.

    The observation `batch` is stamped with the current wall-clock time and
    pushed to `self._obs_queue`. On the first call a daemon thread running
    `self.inference_loop` is started; it fills `self._action_queue` with
    timestamped actions. This method then polls the action queue until an
    action is available for the time at which it was called.

    Args:
        batch: Environment observations handed to the inference thread.

    Returns:
        The action retrieved from `self._action_queue` for the call time.

    Raises:
        RuntimeError: If no action is available and the inference thread is
            no longer running, since waiting would then never end.
    """
    present_time = time.time()
    self._obs_queue.add(batch, present_time)

    if self.thread is None:
        # Lazily start the background inference worker on first use.
        self.stop_event = threading.Event()
        self.thread = Thread(target=self.inference_loop, daemon=True)
        self.thread.start()

    next_action = None
    while next_action is None:
        try:
            next_action = self._action_queue.get(present_time)
        except ValueError:
            # No action is available at this present time. Only keep waiting
            # while the worker can still produce one; if it has exited
            # (crashed or been stopped), polling would hang forever.
            if not self.thread.is_alive():
                raise RuntimeError(
                    "Inference thread is not running and no action is available "
                    "for the requested time."
                ) from None
            time.sleep(0.1)
    return next_action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""