Merge remote-tracking branch 'upstream/main' into unify_policy_api

This commit is contained in:
Alexander Soare
2024-04-16 17:30:41 +01:00
4 changed files with 23 additions and 17 deletions

View File

@@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
def __init__(self, cfg: ActionChunkingTransformerConfig):
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
"""
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
super().__init__()
if getattr(cfg, "n_obs_steps", 1) != 1:
if cfg is None:
cfg = ActionChunkingTransformerConfig()
if cfg.n_obs_steps != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg
@@ -161,8 +165,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
return self.select_action(self, batch)
# def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
# return self.select_action(self, batch)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: