Clear action queue when environment is reset

This commit is contained in:
Alexander Soare
2024-03-20 08:31:06 +00:00
parent c5010fee9a
commit 4f1955edfd
2 changed files with 9 additions and 3 deletions

View File

@@ -20,8 +20,7 @@ class AbstractPolicy(nn.Module, ABC):
"""
super().__init__()
self.n_action_steps = n_action_steps
if n_action_steps is not None:
self._action_queue = deque([], maxlen=n_action_steps)
self.clear_action_queue()
@abstractmethod
def update(self, replay_buffer, step):
@@ -42,6 +41,11 @@ class AbstractPolicy(nn.Module, ABC):
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
"""
def clear_action_queue(self):
"""This should be called whenever the environment is reset."""
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
def forward(self, *args, **kwargs) -> Tensor:
"""Inference step that makes multi-step policies compatible with their single-step environments.