Refactor TD-MPC (#103)

Co-authored-by: Cadene <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Alexander Soare
2024-05-01 16:40:04 +01:00
committed by GitHub
parent a4891095e4
commit d1855a202a
17 changed files with 1105 additions and 1205 deletions

View File

@@ -21,6 +21,14 @@ class Policy(Protocol):
name: str
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
dataset_stats: Dataset statistics to be used for normalization.
"""
def reset(self):
"""To be called whenever the environment is reset.
@@ -39,3 +47,13 @@ class Policy(Protocol):
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
@runtime_checkable
class PolicyWithUpdate(Policy, Protocol):
def update(self):
"""An update method that is to be called after a training optimization step.
Implements an additional updates the model parameters may need (for example, doing an EMA step for a
target model, or incrementing an internal buffer).
"""