backup wip
This commit is contained in:
@@ -56,7 +56,7 @@ class ActionChunkingTransformerConfig:
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 1
|
||||
camera_names: list[str] = field(default_factory=lambda: ["top"])
|
||||
camera_names: tuple[str] = ("top",)
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
@@ -101,7 +101,7 @@ class ActionChunkingTransformerConfig:
|
||||
utd: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation."""
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError("`vision_backbone` must be one of the ResNet variants.")
|
||||
if self.use_temporal_aggregation:
|
||||
|
||||
@@ -163,7 +163,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
|
||||
Reference in New Issue
Block a user