Updated version with a queue of action sequences, instead of queue of action
This commit is contained in:
@@ -94,7 +94,9 @@ class ACTPolicy(
|
||||
|
||||
# TODO(rcadene): Add delta timestamps in policy
|
||||
FPS = 10 # noqa: N806
|
||||
self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
|
||||
self.delta_timestamps = {
|
||||
"action": [i / FPS for i in range(self.config.n_action_steps)],
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
@@ -103,7 +105,11 @@ class ACTPolicy(
|
||||
else:
|
||||
# TODO(rcadene): set proper maxlen
|
||||
self._obs_queue = TemporalQueue(maxlen=1)
|
||||
self._action_queue = TemporalQueue(maxlen=200)
|
||||
self._action_seq_queue = TemporalQueue(maxlen=200)
|
||||
|
||||
self._action_sequence = None
|
||||
self._action_seq_index = 0
|
||||
self._action_seq_timestamp = None
|
||||
|
||||
@torch.no_grad
|
||||
def inference(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -135,13 +141,20 @@ class ACTPolicy(
|
||||
|
||||
if prev_timestamp is not None and prev_timestamp == last_timestamp:
|
||||
# in case inference ran faster than recording/adding a new observation in the queue
|
||||
# print("WAIT INFERENCE")
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
start_t = time.perf_counter()
|
||||
|
||||
pred_action_sequence = self.inference(last_observation)
|
||||
|
||||
for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
|
||||
self._action_queue.add(action, last_timestamp + delta_ts)
|
||||
dt_s = time.perf_counter() - start_t
|
||||
print(
|
||||
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
|
||||
) # , {next_action.mean().item()}")
|
||||
|
||||
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
|
||||
|
||||
prev_timestamp = last_timestamp
|
||||
|
||||
@@ -155,14 +168,47 @@ class ACTPolicy(
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
next_action = None
|
||||
while next_action is None:
|
||||
try:
|
||||
next_action = self._action_queue.get(present_time)
|
||||
except ValueError:
|
||||
time.sleep(0.1) # no action available at this present time, we wait a bit
|
||||
while len(self._action_seq_queue) == 0:
|
||||
# print("WAIT")
|
||||
time.sleep(0.1) # no action available at this present time, we wait a bit
|
||||
|
||||
return next_action
|
||||
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
|
||||
|
||||
if self._action_seq_index == len(self.delta_timestamps["action"]):
|
||||
while self._action_seq_timestamp == latest_seq_timestamp:
|
||||
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
|
||||
# print("WAIT")
|
||||
time.sleep(0.1)
|
||||
|
||||
if self._action_seq_timestamp is None:
|
||||
self._action_sequence = latest_action_sequence
|
||||
self._action_seq_timestamp = latest_seq_timestamp
|
||||
|
||||
elif self._action_seq_index == 100 and self._action_seq_timestamp < latest_seq_timestamp:
|
||||
# update sequence index
|
||||
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||
if self._action_seq_index == len(self.delta_timestamps["action"]):
|
||||
current_timestamp = seq_timestamps[-1]
|
||||
else:
|
||||
current_timestamp = seq_timestamps[self._action_seq_index]
|
||||
|
||||
latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||
distances = np.abs(latest_seq_timestamps - current_timestamp)
|
||||
nearest_idx = distances.argmin()
|
||||
# TODO(rcadene): handle edge cases
|
||||
self._action_seq_index = nearest_idx
|
||||
|
||||
# update action sequence
|
||||
self._action_sequence = latest_action_sequence
|
||||
# update inference timestamp (when this action sequence has been computed)
|
||||
self._action_seq_timestamp = latest_seq_timestamp
|
||||
|
||||
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||
current_timestamp = seq_timestamps[self._action_seq_index]
|
||||
|
||||
action = self._action_sequence[:, self._action_seq_index]
|
||||
self._action_seq_index += 1
|
||||
return action, present_time, current_timestamp
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
|
||||
Reference in New Issue
Block a user