Fix policy defaults (#113)

This commit is contained in:
Alexander Soare
2024-04-29 08:26:59 +01:00
committed by GitHub
parent 791506dfb8
commit ccffa9e406
3 changed files with 13 additions and 2 deletions

View File

@@ -77,7 +77,7 @@ class ActionChunkingTransformerConfig:
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.images.top": "mean_std",
"observation.state": "mean_std",
}
)

View File

@@ -43,7 +43,7 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
):
"""
Args: