This commit is contained in:
Remi Cadene
2024-10-16 12:10:40 +02:00
parent a7841afaa4
commit a1a7f7887f
3 changed files with 52 additions and 76 deletions

View File

@@ -20,10 +20,9 @@ The majority of changes here involve removing unused code, unifying naming, and
"""
import math
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from threading import Thread
from typing import Callable
import einops
@@ -38,7 +37,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import TemporalQueue
class ACTPolicy(
@@ -90,10 +88,9 @@ class ACTPolicy(
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
self.thread = None
# TODO(rcadene): Add delta timestamps in policy
FPS = 10 # noqa: N806
FPS = 30 # noqa: N806
self.delta_timestamps = {
"action": [i / FPS for i in range(self.config.n_action_steps)],
}
@@ -104,12 +101,11 @@ class ACTPolicy(
self.temporal_ensembler.reset()
else:
# TODO(rcadene): set proper maxlen
self._obs_queue = TemporalQueue(maxlen=1)
self._action_seq_queue = TemporalQueue(maxlen=200)
self._action_sequence = None
self._action_seq_index = 0
self._action_seq_timestamp = None
self.executor = None
self.future = None
self._actions = None
self._actions_timestamps = None
self._action_index = 0
@torch.no_grad
def inference(self, batch: dict[str, Tensor]) -> Tensor:
@@ -134,77 +130,50 @@ class ACTPolicy(
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def inference_loop(self):
while not self.stop_event.is_set():
with self.condition:
self.condition.wait()
def inference_with_timestamp(self, batch, timestamp):
start_t = time.perf_counter()
start_t = time.perf_counter()
actions = self.inference(batch)
last_observation, last_timestamp = self._obs_queue.get_latest()
pred_action_sequence = self.inference(last_observation)
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
dt_s = time.perf_counter() - start_t
print(f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {timestamp}")
dt_s = time.perf_counter() - start_t
print(
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
) # , {next_action.mean().item()}")
self.new_action_seq_event.set()
return actions, timestamp
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
present_time = time.time()
self._obs_queue.add(batch, present_time)
present_timestamp = time.time()
if self.thread is None:
self.stop_event = threading.Event()
self.new_action_seq_event = threading.Event()
if self.executor is None:
self.executor = ThreadPoolExecutor(max_workers=1)
self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp)
actions, inference_timestamp = self.future.result()
self._actions = actions
self._actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"])
self.condition = threading.Condition()
self.thread = Thread(target=self.inference_loop, args=())
self.thread.daemon = True
self.thread.start()
if self._action_index == 90:
self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp)
# Ask thread to run first inference
with self.condition:
self.condition.notify()
if self._action_index >= self._actions.shape[1]:
actions, inference_timestamp = self.future.result()
# Block the main thread until the inference thread has run its first inference
self.new_action_seq_event.wait()
self._action_sequence, self._action_seq_timestamp = self._action_seq_queue.get_latest()
if self._action_seq_index == 97:
with self.condition:
self.condition.notify()
if self._action_seq_index >= len(self._action_sequence):
self.new_action_seq_event.wait()
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
# update sequence index
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
if self._action_seq_index == len(self.delta_timestamps["action"]):
current_timestamp = seq_timestamps[-1]
else:
current_timestamp = seq_timestamps[self._action_seq_index]
latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
distances = np.abs(latest_seq_timestamps - current_timestamp)
# find corresponding action_index in new actions
present_action_timestamp = self._actions_timestamps[-1]
new_actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"])
distances = np.abs(new_actions_timestamps - present_action_timestamp)
nearest_idx = distances.argmin()
# update
self._action_index = nearest_idx
self._actions_timestamps = new_actions_timestamps
self._actions = actions
# TODO(rcadene): handle edge cases
self._action_seq_index = nearest_idx
# update action sequence
self._action_sequence = latest_action_sequence
# update inference timestamp (when this action sequence has been computed)
self._action_seq_timestamp = latest_seq_timestamp
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
current_timestamp = seq_timestamps[self._action_seq_index]
action = self._action_sequence[:, self._action_seq_index]
self._action_seq_index += 1
return action, present_time, current_timestamp
action = self._actions[:, self._action_index]
present_action_timestamp = self._actions_timestamps[self._action_index]
self._action_index += 1
self._present_timestamp = present_timestamp
self._present_action_timestamp = present_action_timestamp
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""

View File

@@ -3,17 +3,24 @@ import time
def busy_wait(seconds):
if seconds <= 0:
return
if platform.system() == "Darwin":
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
# but it consumes CPU cycles.
# TODO(rcadene): find an alternative: from Python 3.11 onwards, `time.sleep` is precise
end_time = time.perf_counter() + seconds
start_sleep = time.perf_counter()
time.sleep(seconds / 2)
dt_sleep = time.perf_counter() - start_sleep
end_time = time.perf_counter() + (seconds - dt_sleep)
while time.perf_counter() < end_time:
pass
else:
# On Linux time.sleep is accurate
if seconds > 0:
time.sleep(seconds)
time.sleep(seconds)
class RobotDeviceNotConnectedError(Exception):

View File

@@ -40,17 +40,17 @@ def main(env_name, policy_name, extra_overrides):
fps = 30
for i in range(200):
for i in range(400):
start_loop_t = time.perf_counter()
next_action, present_time, action_ts = policy.select_action(obs)
next_action = policy.select_action(obs) # noqa: F841
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
print(
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}"
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{policy._present_timestamp}\t{policy._present_action_timestamp}"
) # , {next_action.mean().item()}")
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)