Updated version with a queue of action sequences, instead of a queue of actions

This commit is contained in:
Remi Cadene
2024-10-14 12:23:29 +02:00
parent 1f59edd5e7
commit 6afdf2f626
3 changed files with 78 additions and 27 deletions

View File

@@ -94,7 +94,9 @@ class ACTPolicy(
# TODO(rcadene): Add delta timestamps in policy
FPS = 10 # noqa: N806
self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
self.delta_timestamps = {
"action": [i / FPS for i in range(self.config.n_action_steps)],
}
def reset(self):
"""This should be called whenever the environment is reset."""
@@ -103,7 +105,11 @@ class ACTPolicy(
else:
# TODO(rcadene): set proper maxlen
self._obs_queue = TemporalQueue(maxlen=1)
self._action_queue = TemporalQueue(maxlen=200)
self._action_seq_queue = TemporalQueue(maxlen=200)
self._action_sequence = None
self._action_seq_index = 0
self._action_seq_timestamp = None
@torch.no_grad
def inference(self, batch: dict[str, Tensor]) -> Tensor:
@@ -135,13 +141,20 @@ class ACTPolicy(
if prev_timestamp is not None and prev_timestamp == last_timestamp:
# in case inference ran faster than recording/adding a new observation in the queue
# print("WAIT INFERENCE")
time.sleep(0.1)
continue
start_t = time.perf_counter()
pred_action_sequence = self.inference(last_observation)
for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
self._action_queue.add(action, last_timestamp + delta_ts)
dt_s = time.perf_counter() - start_t
print(
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
) # , {next_action.mean().item()}")
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
prev_timestamp = last_timestamp
@@ -155,14 +168,47 @@ class ACTPolicy(
self.thread.daemon = True
self.thread.start()
next_action = None
while next_action is None:
try:
next_action = self._action_queue.get(present_time)
except ValueError:
time.sleep(0.1) # no action available at this present time, we wait a bit
while len(self._action_seq_queue) == 0:
# print("WAIT")
time.sleep(0.1) # no action available at this present time, we wait a bit
return next_action
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
if self._action_seq_index == len(self.delta_timestamps["action"]):
while self._action_seq_timestamp == latest_seq_timestamp:
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
# print("WAIT")
time.sleep(0.1)
if self._action_seq_timestamp is None:
self._action_sequence = latest_action_sequence
self._action_seq_timestamp = latest_seq_timestamp
elif self._action_seq_index == 100 and self._action_seq_timestamp < latest_seq_timestamp:
# update sequence index
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
if self._action_seq_index == len(self.delta_timestamps["action"]):
current_timestamp = seq_timestamps[-1]
else:
current_timestamp = seq_timestamps[self._action_seq_index]
latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
distances = np.abs(latest_seq_timestamps - current_timestamp)
nearest_idx = distances.argmin()
# TODO(rcadene): handle edge cases
self._action_seq_index = nearest_idx
# update action sequence
self._action_sequence = latest_action_sequence
# update inference timestamp (when this action sequence has been computed)
self._action_seq_timestamp = latest_seq_timestamp
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
current_timestamp = seq_timestamps[self._action_seq_index]
action = self._action_sequence[:, self._action_seq_index]
self._action_seq_index += 1
return action, present_time, current_timestamp
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""