forked from tangger/lerobot
Works
This commit is contained in:
@@ -20,10 +20,9 @@ The majority of changes here involve removing unused code, unifying naming, and
|
||||
"""
|
||||
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import chain
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
@@ -38,7 +37,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import TemporalQueue
|
||||
|
||||
|
||||
class ACTPolicy(
|
||||
@@ -90,10 +88,9 @@ class ACTPolicy(
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
|
||||
self.reset()
|
||||
self.thread = None
|
||||
|
||||
# TODO(rcadene): Add delta timestamps in policy
|
||||
FPS = 10 # noqa: N806
|
||||
FPS = 30 # noqa: N806
|
||||
self.delta_timestamps = {
|
||||
"action": [i / FPS for i in range(self.config.n_action_steps)],
|
||||
}
|
||||
@@ -104,12 +101,11 @@ class ACTPolicy(
|
||||
self.temporal_ensembler.reset()
|
||||
else:
|
||||
# TODO(rcadene): set proper maxlen
|
||||
self._obs_queue = TemporalQueue(maxlen=1)
|
||||
self._action_seq_queue = TemporalQueue(maxlen=200)
|
||||
|
||||
self._action_sequence = None
|
||||
self._action_seq_index = 0
|
||||
self._action_seq_timestamp = None
|
||||
self.executor = None
|
||||
self.future = None
|
||||
self._actions = None
|
||||
self._actions_timestamps = None
|
||||
self._action_index = 0
|
||||
|
||||
@torch.no_grad
|
||||
def inference(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -134,77 +130,50 @@ class ACTPolicy(
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def inference_loop(self):
|
||||
while not self.stop_event.is_set():
|
||||
with self.condition:
|
||||
self.condition.wait()
|
||||
def inference_with_timestamp(self, batch, timestamp):
|
||||
start_t = time.perf_counter()
|
||||
|
||||
start_t = time.perf_counter()
|
||||
actions = self.inference(batch)
|
||||
|
||||
last_observation, last_timestamp = self._obs_queue.get_latest()
|
||||
pred_action_sequence = self.inference(last_observation)
|
||||
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
|
||||
dt_s = time.perf_counter() - start_t
|
||||
print(f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {timestamp}")
|
||||
|
||||
dt_s = time.perf_counter() - start_t
|
||||
print(
|
||||
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
|
||||
) # , {next_action.mean().item()}")
|
||||
|
||||
self.new_action_seq_event.set()
|
||||
return actions, timestamp
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
present_time = time.time()
|
||||
self._obs_queue.add(batch, present_time)
|
||||
present_timestamp = time.time()
|
||||
|
||||
if self.thread is None:
|
||||
self.stop_event = threading.Event()
|
||||
self.new_action_seq_event = threading.Event()
|
||||
if self.executor is None:
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp)
|
||||
actions, inference_timestamp = self.future.result()
|
||||
self._actions = actions
|
||||
self._actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"])
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.thread = Thread(target=self.inference_loop, args=())
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
if self._action_index == 90:
|
||||
self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp)
|
||||
|
||||
# Ask thread to run first inference
|
||||
with self.condition:
|
||||
self.condition.notify()
|
||||
if self._action_index >= self._actions.shape[1]:
|
||||
actions, inference_timestamp = self.future.result()
|
||||
|
||||
# Block main process until the thread ran it's first inference
|
||||
self.new_action_seq_event.wait()
|
||||
self._action_sequence, self._action_seq_timestamp = self._action_seq_queue.get_latest()
|
||||
|
||||
if self._action_seq_index == 97:
|
||||
with self.condition:
|
||||
self.condition.notify()
|
||||
|
||||
if self._action_seq_index >= len(self._action_sequence):
|
||||
self.new_action_seq_event.wait()
|
||||
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
|
||||
|
||||
# update sequence index
|
||||
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||
if self._action_seq_index == len(self.delta_timestamps["action"]):
|
||||
current_timestamp = seq_timestamps[-1]
|
||||
else:
|
||||
current_timestamp = seq_timestamps[self._action_seq_index]
|
||||
|
||||
latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||
distances = np.abs(latest_seq_timestamps - current_timestamp)
|
||||
# find corresponding action_index in new actions
|
||||
present_action_timestamp = self._actions_timestamps[-1]
|
||||
new_actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"])
|
||||
distances = np.abs(new_actions_timestamps - present_action_timestamp)
|
||||
nearest_idx = distances.argmin()
|
||||
|
||||
# update
|
||||
self._action_index = nearest_idx
|
||||
self._actions_timestamps = new_actions_timestamps
|
||||
self._actions = actions
|
||||
# TODO(rcadene): handle edge cases
|
||||
self._action_seq_index = nearest_idx
|
||||
|
||||
# update action sequence
|
||||
self._action_sequence = latest_action_sequence
|
||||
# update inference timestamp (when this action sequence has been computed)
|
||||
self._action_seq_timestamp = latest_seq_timestamp
|
||||
|
||||
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||
current_timestamp = seq_timestamps[self._action_seq_index]
|
||||
|
||||
action = self._action_sequence[:, self._action_seq_index]
|
||||
self._action_seq_index += 1
|
||||
return action, present_time, current_timestamp
|
||||
action = self._actions[:, self._action_index]
|
||||
present_action_timestamp = self._actions_timestamps[self._action_index]
|
||||
self._action_index += 1
|
||||
self._present_timestamp = present_timestamp
|
||||
self._present_action_timestamp = present_action_timestamp
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
|
||||
@@ -3,17 +3,24 @@ import time
|
||||
|
||||
|
||||
def busy_wait(seconds):
|
||||
if seconds <= 0:
|
||||
return
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
# but it consumes CPU cycles.
|
||||
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
||||
end_time = time.perf_counter() + seconds
|
||||
|
||||
start_sleep = time.perf_counter()
|
||||
time.sleep(seconds / 2)
|
||||
dt_sleep = time.perf_counter() - start_sleep
|
||||
|
||||
end_time = time.perf_counter() + (seconds - dt_sleep)
|
||||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
else:
|
||||
# On Linux time.sleep is accurate
|
||||
if seconds > 0:
|
||||
time.sleep(seconds)
|
||||
time.sleep(seconds)
|
||||
|
||||
|
||||
class RobotDeviceNotConnectedError(Exception):
|
||||
|
||||
6
test2.py
6
test2.py
@@ -40,17 +40,17 @@ def main(env_name, policy_name, extra_overrides):
|
||||
|
||||
fps = 30
|
||||
|
||||
for i in range(200):
|
||||
for i in range(400):
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
next_action, present_time, action_ts = policy.select_action(obs)
|
||||
next_action = policy.select_action(obs) # noqa: F841
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
print(
|
||||
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}"
|
||||
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{policy._present_timestamp}\t{policy._present_action_timestamp}"
|
||||
) # , {next_action.mean().item()}")
|
||||
|
||||
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
|
||||
|
||||
Reference in New Issue
Block a user