forked from tangger/lerobot
Works
This commit is contained in:
@@ -20,10 +20,9 @@ The majority of changes here involve removing unused code, unifying naming, and
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from threading import Thread
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
@@ -38,7 +37,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
|||||||
|
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.utils import TemporalQueue
|
|
||||||
|
|
||||||
|
|
||||||
class ACTPolicy(
|
class ACTPolicy(
|
||||||
@@ -90,10 +88,9 @@ class ACTPolicy(
|
|||||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
self.thread = None
|
|
||||||
|
|
||||||
# TODO(rcadene): Add delta timestamps in policy
|
# TODO(rcadene): Add delta timestamps in policy
|
||||||
FPS = 10 # noqa: N806
|
FPS = 30 # noqa: N806
|
||||||
self.delta_timestamps = {
|
self.delta_timestamps = {
|
||||||
"action": [i / FPS for i in range(self.config.n_action_steps)],
|
"action": [i / FPS for i in range(self.config.n_action_steps)],
|
||||||
}
|
}
|
||||||
@@ -104,12 +101,11 @@ class ACTPolicy(
|
|||||||
self.temporal_ensembler.reset()
|
self.temporal_ensembler.reset()
|
||||||
else:
|
else:
|
||||||
# TODO(rcadene): set proper maxlen
|
# TODO(rcadene): set proper maxlen
|
||||||
self._obs_queue = TemporalQueue(maxlen=1)
|
self.executor = None
|
||||||
self._action_seq_queue = TemporalQueue(maxlen=200)
|
self.future = None
|
||||||
|
self._actions = None
|
||||||
self._action_sequence = None
|
self._actions_timestamps = None
|
||||||
self._action_seq_index = 0
|
self._action_index = 0
|
||||||
self._action_seq_timestamp = None
|
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def inference(self, batch: dict[str, Tensor]) -> Tensor:
|
def inference(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
@@ -134,77 +130,50 @@ class ACTPolicy(
|
|||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
def inference_loop(self):
|
def inference_with_timestamp(self, batch, timestamp):
|
||||||
while not self.stop_event.is_set():
|
start_t = time.perf_counter()
|
||||||
with self.condition:
|
|
||||||
self.condition.wait()
|
|
||||||
|
|
||||||
start_t = time.perf_counter()
|
actions = self.inference(batch)
|
||||||
|
|
||||||
last_observation, last_timestamp = self._obs_queue.get_latest()
|
dt_s = time.perf_counter() - start_t
|
||||||
pred_action_sequence = self.inference(last_observation)
|
print(f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {timestamp}")
|
||||||
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
|
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_t
|
return actions, timestamp
|
||||||
print(
|
|
||||||
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
|
|
||||||
) # , {next_action.mean().item()}")
|
|
||||||
|
|
||||||
self.new_action_seq_event.set()
|
|
||||||
|
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
present_time = time.time()
|
present_timestamp = time.time()
|
||||||
self._obs_queue.add(batch, present_time)
|
|
||||||
|
|
||||||
if self.thread is None:
|
if self.executor is None:
|
||||||
self.stop_event = threading.Event()
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
self.new_action_seq_event = threading.Event()
|
self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp)
|
||||||
|
actions, inference_timestamp = self.future.result()
|
||||||
|
self._actions = actions
|
||||||
|
self._actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"])
|
||||||
|
|
||||||
self.condition = threading.Condition()
|
if self._action_index == 90:
|
||||||
self.thread = Thread(target=self.inference_loop, args=())
|
self.future = self.executor.submit(self.inference_with_timestamp, batch, present_timestamp)
|
||||||
self.thread.daemon = True
|
|
||||||
self.thread.start()
|
|
||||||
|
|
||||||
# Ask thread to run first inference
|
if self._action_index >= self._actions.shape[1]:
|
||||||
with self.condition:
|
actions, inference_timestamp = self.future.result()
|
||||||
self.condition.notify()
|
|
||||||
|
|
||||||
# Block main process until the thread ran it's first inference
|
# find corresponding action_index in new actions
|
||||||
self.new_action_seq_event.wait()
|
present_action_timestamp = self._actions_timestamps[-1]
|
||||||
self._action_sequence, self._action_seq_timestamp = self._action_seq_queue.get_latest()
|
new_actions_timestamps = inference_timestamp + np.array(self.delta_timestamps["action"])
|
||||||
|
distances = np.abs(new_actions_timestamps - present_action_timestamp)
|
||||||
if self._action_seq_index == 97:
|
|
||||||
with self.condition:
|
|
||||||
self.condition.notify()
|
|
||||||
|
|
||||||
if self._action_seq_index >= len(self._action_sequence):
|
|
||||||
self.new_action_seq_event.wait()
|
|
||||||
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
|
|
||||||
|
|
||||||
# update sequence index
|
|
||||||
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
|
||||||
if self._action_seq_index == len(self.delta_timestamps["action"]):
|
|
||||||
current_timestamp = seq_timestamps[-1]
|
|
||||||
else:
|
|
||||||
current_timestamp = seq_timestamps[self._action_seq_index]
|
|
||||||
|
|
||||||
latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
|
|
||||||
distances = np.abs(latest_seq_timestamps - current_timestamp)
|
|
||||||
nearest_idx = distances.argmin()
|
nearest_idx = distances.argmin()
|
||||||
|
|
||||||
|
# update
|
||||||
|
self._action_index = nearest_idx
|
||||||
|
self._actions_timestamps = new_actions_timestamps
|
||||||
|
self._actions = actions
|
||||||
# TODO(rcadene): handle edge cases
|
# TODO(rcadene): handle edge cases
|
||||||
self._action_seq_index = nearest_idx
|
|
||||||
|
|
||||||
# update action sequence
|
action = self._actions[:, self._action_index]
|
||||||
self._action_sequence = latest_action_sequence
|
present_action_timestamp = self._actions_timestamps[self._action_index]
|
||||||
# update inference timestamp (when this action sequence has been computed)
|
self._action_index += 1
|
||||||
self._action_seq_timestamp = latest_seq_timestamp
|
self._present_timestamp = present_timestamp
|
||||||
|
self._present_action_timestamp = present_action_timestamp
|
||||||
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
return action
|
||||||
current_timestamp = seq_timestamps[self._action_seq_index]
|
|
||||||
|
|
||||||
action = self._action_sequence[:, self._action_seq_index]
|
|
||||||
self._action_seq_index += 1
|
|
||||||
return action, present_time, current_timestamp
|
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
|
|||||||
@@ -3,17 +3,24 @@ import time
|
|||||||
|
|
||||||
|
|
||||||
def busy_wait(seconds):
|
def busy_wait(seconds):
|
||||||
|
if seconds <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
if platform.system() == "Darwin":
|
if platform.system() == "Darwin":
|
||||||
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||||
# but it consumes CPU cycles.
|
# but it consumes CPU cycles.
|
||||||
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
||||||
end_time = time.perf_counter() + seconds
|
|
||||||
|
start_sleep = time.perf_counter()
|
||||||
|
time.sleep(seconds / 2)
|
||||||
|
dt_sleep = time.perf_counter() - start_sleep
|
||||||
|
|
||||||
|
end_time = time.perf_counter() + (seconds - dt_sleep)
|
||||||
while time.perf_counter() < end_time:
|
while time.perf_counter() < end_time:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# On Linux time.sleep is accurate
|
# On Linux time.sleep is accurate
|
||||||
if seconds > 0:
|
time.sleep(seconds)
|
||||||
time.sleep(seconds)
|
|
||||||
|
|
||||||
|
|
||||||
class RobotDeviceNotConnectedError(Exception):
|
class RobotDeviceNotConnectedError(Exception):
|
||||||
|
|||||||
6
test2.py
6
test2.py
@@ -40,17 +40,17 @@ def main(env_name, policy_name, extra_overrides):
|
|||||||
|
|
||||||
fps = 30
|
fps = 30
|
||||||
|
|
||||||
for i in range(200):
|
for i in range(400):
|
||||||
start_loop_t = time.perf_counter()
|
start_loop_t = time.perf_counter()
|
||||||
|
|
||||||
next_action, present_time, action_ts = policy.select_action(obs)
|
next_action = policy.select_action(obs) # noqa: F841
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_loop_t
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_loop_t
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
print(
|
print(
|
||||||
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}"
|
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{policy._present_timestamp}\t{policy._present_action_timestamp}"
|
||||||
) # , {next_action.mean().item()}")
|
) # , {next_action.mean().item()}")
|
||||||
|
|
||||||
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
|
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
|
||||||
|
|||||||
Reference in New Issue
Block a user