This commit is contained in:
Alexander Soare
2024-03-20 09:45:45 +00:00
parent b1ec3da035
commit 5332766a82
4 changed files with 34 additions and 233 deletions

View File

@@ -65,6 +65,7 @@ class AbstractPolicy(nn.Module, ABC):
if self.n_action_steps is None:
return self.select_actions(*args, **kwargs)
if len(self._action_queue) == 0:
# Each element in the queue has shape (B, *).
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
# (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()