wip: still needs batch logic for act and tdmp

This commit is contained in:
Alexander Soare
2024-03-14 15:22:55 +00:00
parent 8c56770318
commit ba91976944
11 changed files with 240 additions and 100 deletions

View File

@@ -0,0 +1,54 @@
from abc import abstractmethod
from collections import deque
import torch
from torch import Tensor, nn
class AbstractPolicy(nn.Module):
@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
pass
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
@abstractmethod
def select_action(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
Should return a (batch_size, n_action_steps, *) tensor of actions.
"""
pass
def forward(self, *args, **kwargs):
"""Inference step that makes multi-step policies compatible with their single-step environments.
WARNING: In general, this should not be overriden.
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * is the action dimensions.
2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
"""
n_action_steps_attr = "n_action_steps"
if not hasattr(self, n_action_steps_attr):
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
if not hasattr(self, "_action_queue"):
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
if len(self._action_queue) == 0:
# Each element in the queue has shape (B, *).
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()