This commit is contained in:
Remi Cadene
2024-10-16 12:10:40 +02:00
parent a7841afaa4
commit a1a7f7887f
3 changed files with 52 additions and 76 deletions

View File

@@ -20,10 +20,9 @@ The majority of changes here involve removing unused code, unifying naming, and
""" """
import math import math
import threading
import time import time
from concurrent.futures import ThreadPoolExecutor
from itertools import chain from itertools import chain
from threading import Thread
from typing import Callable from typing import Callable
import einops import einops
@@ -38,7 +37,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import TemporalQueue
class ACTPolicy( class ACTPolicy(
@@ -90,10 +88,9 @@ class ACTPolicy(
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset() self.reset()
self.thread = None
# TODO(rcadene): Add delta timestamps in policy # TODO(rcadene): Add delta timestamps in policy
FPS = 10 # noqa: N806 FPS = 30 # noqa: N806
self.delta_timestamps = { self.delta_timestamps = {
"action": [i / FPS for i in range(self.config.n_action_steps)], "action": [i / FPS for i in range(self.config.n_action_steps)],
} }
@@ -104,12 +101,11 @@ class ACTPolicy(
self.temporal_ensembler.reset() self.temporal_ensembler.reset()
else: else:
# TODO(rcadene): set proper maxlen # TODO(rcadene): set proper maxlen
self._obs_queue = TemporalQueue(maxlen=1) self.executor = None
self._action_seq_queue = TemporalQueue(maxlen=200) self.future = None
self._actions = None
self._action_sequence = None self._actions_timestamps = None
self._action_seq_index = 0 self._action_index = 0
self._action_seq_timestamp = None
@torch.no_grad @torch.no_grad
def inference(self, batch: dict[str, Tensor]) -> Tensor: def inference(self, batch: dict[str, Tensor]) -> Tensor:
@@ -134,77 +130,50 @@ class ACTPolicy(
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
return actions return actions
def inference_loop(self): def inference_with_timestamp(self, batch, timestamp):
while not self.stop_event.is_set(): start_t = time.perf_counter()
with self.condition:
self.condition.wait()
start_t = time.perf_counter() actions = self.inference(batch)
last_observation, last_timestamp = self._obs_queue.get_latest() dt_s = time.perf_counter() - start_t
pred_action_sequence = self.inference(last_observation) print(f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {timestamp}")
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
dt_s = time.perf_counter() - start_t return actions, timestamp
print(
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
) # , {next_action.mean().item()}")
self.new_action_seq_event.set()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
present_time = time.time() present_timestamp = time.time()
self._obs_queue.add(batch, present_time)
if self.thread is None: if self.executor is None:
self.stop_event = threading.Event() self.executor = ThreadPoolExecutor(max_workers=1)
self.new_action_seq_event = threading.Event() self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp)
actions, inference_timestamp = self.future.result()
self._actions = actions
self._actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"])
self.condition = threading.Condition() if self._action_index == 90:
self.thread = Thread(target=self.inference_loop, args=()) self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp)
self.thread.daemon = True
self.thread.start()
# Ask thread to run first inference if self._action_index >= self._actions.shape[1]:
with self.condition: actions, inference_timestamp = self.future.result()
self.condition.notify()
# Block main process until the thread ran it's first inference # find corresponding action_index in new actions
self.new_action_seq_event.wait() present_action_timestamp = self._actions_timestamps[-1]
self._action_sequence, self._action_seq_timestamp = self._action_seq_queue.get_latest() new_actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"])
distances = np.abs(new_actions_timestamps - present_action_timestamp)
if self._action_seq_index == 97:
with self.condition:
self.condition.notify()
if self._action_seq_index >= len(self._action_sequence):
self.new_action_seq_event.wait()
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
# update sequence index
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
if self._action_seq_index == len(self.delta_timestamps["action"]):
current_timestamp = seq_timestamps[-1]
else:
current_timestamp = seq_timestamps[self._action_seq_index]
latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
distances = np.abs(latest_seq_timestamps - current_timestamp)
nearest_idx = distances.argmin() nearest_idx = distances.argmin()
# update
self._action_index = nearest_idx
self._actions_timestamps = new_actions_timestamps
self._actions = actions
# TODO(rcadene): handle edge cases # TODO(rcadene): handle edge cases
self._action_seq_index = nearest_idx
# update action sequence action = self._actions[:, self._action_index]
self._action_sequence = latest_action_sequence present_action_timestamp = self._actions_timestamps[self._action_index]
# update inference timestamp (when this action sequence has been computed) self._action_index += 1
self._action_seq_timestamp = latest_seq_timestamp self._present_timestamp = present_timestamp
self._present_action_timestamp = present_action_timestamp
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"]) return action
current_timestamp = seq_timestamps[self._action_seq_index]
action = self._action_sequence[:, self._action_seq_index]
self._action_seq_index += 1
return action, present_time, current_timestamp
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation.""" """Run the batch through the model and compute the loss for training or validation."""

View File

@@ -3,17 +3,24 @@ import time
def busy_wait(seconds): def busy_wait(seconds):
if seconds <= 0:
return
if platform.system() == "Darwin": if platform.system() == "Darwin":
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick, # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
# but it consumes CPU cycles. # but it consumes CPU cycles.
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
end_time = time.perf_counter() + seconds
start_sleep = time.perf_counter()
time.sleep(seconds / 2)
dt_sleep = time.perf_counter() - start_sleep
end_time = time.perf_counter() + (seconds - dt_sleep)
while time.perf_counter() < end_time: while time.perf_counter() < end_time:
pass pass
else: else:
# On Linux time.sleep is accurate # On Linux time.sleep is accurate
if seconds > 0: time.sleep(seconds)
time.sleep(seconds)
class RobotDeviceNotConnectedError(Exception): class RobotDeviceNotConnectedError(Exception):

View File

@@ -40,17 +40,17 @@ def main(env_name, policy_name, extra_overrides):
fps = 30 fps = 30
for i in range(200): for i in range(400):
start_loop_t = time.perf_counter() start_loop_t = time.perf_counter()
next_action, present_time, action_ts = policy.select_action(obs) next_action = policy.select_action(obs) # noqa: F841
dt_s = time.perf_counter() - start_loop_t dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t dt_s = time.perf_counter() - start_loop_t
print( print(
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}" f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{policy._present_timestamp}\t{policy._present_action_timestamp}"
) # , {next_action.mean().item()}") ) # , {next_action.mean().item()}")
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s) # time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)