diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index fd7cfa7ac..3397d8cd3 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,10 +20,9 @@ The majority of changes here involve removing unused code, unifying naming, and """ import math -import threading import time +from concurrent.futures import ThreadPoolExecutor from itertools import chain -from threading import Thread from typing import Callable import einops @@ -38,7 +37,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.normalize import Normalize, Unnormalize -from lerobot.common.policies.utils import TemporalQueue class ACTPolicy( @@ -90,10 +88,9 @@ class ACTPolicy( self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) self.reset() - self.thread = None # TODO(rcadene): Add delta timestamps in policy - FPS = 10 # noqa: N806 + FPS = 30 # noqa: N806 self.delta_timestamps = { "action": [i / FPS for i in range(self.config.n_action_steps)], } @@ -104,12 +101,11 @@ class ACTPolicy( self.temporal_ensembler.reset() else: # TODO(rcadene): set proper maxlen - self._obs_queue = TemporalQueue(maxlen=1) - self._action_seq_queue = TemporalQueue(maxlen=200) - - self._action_sequence = None - self._action_seq_index = 0 - self._action_seq_timestamp = None + self.executor = None + self.future = None + self._actions = None + self._actions_timestamps = None + self._action_index = 0 @torch.no_grad def inference(self, batch: dict[str, Tensor]) -> Tensor: @@ -134,77 +130,50 @@ class ACTPolicy( actions = self.unnormalize_outputs({"action": actions})["action"] return actions - def inference_loop(self): - while not self.stop_event.is_set(): - with self.condition: - self.condition.wait() + def inference_with_timestamp(self, batch, timestamp): + start_t = time.perf_counter() - start_t = time.perf_counter() + actions = self.inference(batch) - last_observation, last_timestamp = self._obs_queue.get_latest() - pred_action_sequence = self.inference(last_observation) - self._action_seq_queue.add(pred_action_sequence, last_timestamp) + dt_s = time.perf_counter() - start_t + print(f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {timestamp}") - dt_s = time.perf_counter() - start_t - print( - f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}" - ) # , {next_action.mean().item()}") - - self.new_action_seq_event.set() + return actions, timestamp def select_action(self, batch: dict[str, Tensor]) -> Tensor: - present_time = time.time() - self._obs_queue.add(batch, present_time) + present_timestamp = time.time() - if self.thread is None: - self.stop_event = threading.Event() - self.new_action_seq_event = threading.Event() + if self.executor is None: + self.executor = ThreadPoolExecutor(max_workers=1) + self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp) + actions, inference_timestamp = self.future.result() + self._actions = actions + self._actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"]) - self.condition = threading.Condition() - self.thread = Thread(target=self.inference_loop, args=()) - self.thread.daemon = True - self.thread.start() + if self._action_index == 90: + self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp) - # Ask thread to run first inference - with self.condition: - self.condition.notify() + if self._action_index >= self._actions.shape[1]: + actions, inference_timestamp = self.future.result() - # Block main process until the thread ran it's first inference - self.new_action_seq_event.wait() - self._action_sequence, self._action_seq_timestamp = self._action_seq_queue.get_latest() - - if self._action_seq_index == 97: - with self.condition: - self.condition.notify() - - if self._action_seq_index >= len(self._action_sequence): - self.new_action_seq_event.wait() - latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest() - - # update sequence index - seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"]) - if self._action_seq_index == len(self.delta_timestamps["action"]): - current_timestamp = seq_timestamps[-1] - else: - current_timestamp = seq_timestamps[self._action_seq_index] - - latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"]) - distances = np.abs(latest_seq_timestamps - current_timestamp) + # find corresponding action_index in new actions + present_action_timestamp = self._actions_timestamps[-1] + new_actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"]) + distances = np.abs(new_actions_timestamps - present_action_timestamp) nearest_idx = distances.argmin() + + # update + self._action_index = nearest_idx + self._actions_timestamps = new_actions_timestamps + self._actions = actions # TODO(rcadene): handle edge cases - self._action_seq_index = nearest_idx - # update action sequence - self._action_sequence = latest_action_sequence - # update inference timestamp (when this action sequence has been computed) - self._action_seq_timestamp = latest_seq_timestamp - - seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"]) - current_timestamp = seq_timestamps[self._action_seq_index] - - action = self._action_sequence[:, self._action_seq_index] - self._action_seq_index += 1 - return action, present_time, current_timestamp + action = self._actions[:, self._action_index] + present_action_timestamp = self._actions_timestamps[self._action_index] + self._action_index += 1 + self._present_timestamp = present_timestamp + self._present_action_timestamp = present_action_timestamp + return action def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py index bcbeb8e02..ea06e2487 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/robot_devices/utils.py @@ -3,17 +3,24 @@ import time def busy_wait(seconds): + if seconds <= 0: + return + if platform.system() == "Darwin": # On Mac, `time.sleep` is not accurate and we need to use this while loop trick, # but it consumes CPU cycles. # TODO(rcadene): find an alternative: from python 11, time.sleep is precise - end_time = time.perf_counter() + seconds + + start_sleep = time.perf_counter() + time.sleep(seconds / 2) + dt_sleep = time.perf_counter() - start_sleep + + end_time = time.perf_counter() + (seconds - dt_sleep) while time.perf_counter() < end_time: pass else: # On Linux time.sleep is accurate - if seconds > 0: - time.sleep(seconds) + time.sleep(seconds) class RobotDeviceNotConnectedError(Exception): diff --git a/test2.py b/test2.py index eae8210b8..922ed7245 100644 --- a/test2.py +++ b/test2.py @@ -40,17 +40,17 @@ def main(env_name, policy_name, extra_overrides): fps = 30 - for i in range(200): + for i in range(400): start_loop_t = time.perf_counter() - next_action, present_time, action_ts = policy.select_action(obs) + next_action = policy.select_action(obs) # noqa: F841 dt_s = time.perf_counter() - start_loop_t busy_wait(1 / fps - dt_s) dt_s = time.perf_counter() - start_loop_t print( - f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}" + f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{policy._present_timestamp}\t{policy._present_action_timestamp}" ) # , {next_action.mean().item()}") # time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)