revert dp changes, make act and tdmpc batch friendly

This commit is contained in:
Alexander Soare
2024-03-18 19:18:21 +00:00
parent 09ddd9bf92
commit 88347965c2
8 changed files with 32 additions and 58 deletions

View File

@@ -1,15 +1,20 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections import deque
import torch
from torch import Tensor, nn
class AbstractPolicy(nn.Module):
class AbstractPolicy(nn.Module, ABC):
"""Base policy which all policies should be derived from.
The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
documentation for more information.
"""
@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
pass
def save(self, fp):
torch.save(self.state_dict(), fp)
@@ -24,7 +29,6 @@ class AbstractPolicy(nn.Module):
Should return a (batch_size, n_action_steps, *) tensor of actions.
"""
pass
def forward(self, *args, **kwargs):
"""Inference step that makes multi-step policies compatible with their single-step environments.