Clear action queue when environment is reset
This commit is contained in:
@@ -20,8 +20,7 @@ class AbstractPolicy(nn.Module, ABC):
|
||||
"""
|
||||
super().__init__()
|
||||
self.n_action_steps = n_action_steps
|
||||
if n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=n_action_steps)
|
||||
self.clear_action_queue()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, replay_buffer, step):
|
||||
@@ -42,6 +41,11 @@ class AbstractPolicy(nn.Module, ABC):
|
||||
actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
|
||||
"""
|
||||
|
||||
def clear_action_queue(self):
|
||||
"""Replace the action queue with a new, empty deque. This should be called whenever the environment is reset."""
|
||||
# When n_action_steps is None this method does nothing; otherwise the new deque is capped at n_action_steps entries.
if self.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.n_action_steps)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Tensor:
|
||||
"""Inference step that makes multi-step policies compatible with their single-step environments.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user