ready for review

This commit is contained in:
Alexander Soare
2024-04-08 13:10:19 +01:00
parent 1bab4a1dd5
commit 863f28ffd8
7 changed files with 92 additions and 168 deletions

View File

@@ -67,7 +67,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
def __init__(self, cfg, device, n_action_steps=1):
"""
TODO(alexander-soare): Add documentation for all parameters.
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
"""
super().__init__()
if getattr(cfg, "n_obs_steps", 1) != 1:
@@ -109,6 +109,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
)
# Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
)
backbone_model = getattr(torchvision.models, cfg.backbone)(
replace_stride_with_dilation=[False, False, cfg.dilation],
pretrained=cfg.pretrained_backbone,
@@ -275,9 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info
def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
# TODO(now): Maybe this shouldn't be here?
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
images = normalize(batch["observation.images.top"])
images = self.image_normalizer(batch["observation.images.top"])
if return_loss: # training time
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(