revision 1
This commit is contained in:
@@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
|
||||
)
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||
"""
|
||||
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if getattr(cfg, "n_obs_steps", 1) != 1:
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
if cfg.n_obs_steps != 1:
|
||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
||||
self.cfg = cfg
|
||||
|
||||
|
||||
Reference in New Issue
Block a user