This commit is contained in:
Alexander Soare
2024-04-17 10:50:54 +01:00
parent 18dd8f32cd
commit c50a13ab31
3 changed files with 79 additions and 103 deletions

View File

@@ -103,12 +103,21 @@ class ActionChunkingTransformerConfig:
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError("`vision_backbone` must be one of the ResNet variants.")
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.use_temporal_aggregation:
raise NotImplementedError("Temporal aggregation is not yet implemented.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
"The chunk size is the upper bound for the number of action steps per model invocation."
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.camera_names != ["top"]:
raise ValueError("For now, `camera_names` can only be ['top']")
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
if len(set(self.camera_names)) != len(self.camera_names):
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")