remove abstracmethods, fix online training
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import abc
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
@@ -44,26 +43,20 @@ class AbstractEnv(EnvBase):
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _step(self, tensordict: TensorDict):
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _make_env(self):
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _make_spec(self):
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class AbstractPolicy(nn.Module, ABC):
|
||||
class AbstractPolicy(nn.Module):
|
||||
"""Base policy which all policies should be derived from.
|
||||
|
||||
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
||||
@@ -22,9 +21,9 @@ class AbstractPolicy(nn.Module, ABC):
|
||||
self.n_action_steps = n_action_steps
|
||||
self.clear_action_queue()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, replay_buffer, step):
|
||||
"""One step of the policy's learning algorithm."""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
@@ -33,13 +32,13 @@ class AbstractPolicy(nn.Module, ABC):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
@abstractmethod
|
||||
def select_actions(self, observation) -> Tensor:
|
||||
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
||||
|
||||
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
|
||||
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
|
||||
"""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def clear_action_queue(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
|
||||
Reference in New Issue
Block a user