ready for review
This commit is contained in:
@@ -67,7 +67,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
"""
|
||||
TODO(alexander-soare): Add documentation for all parameters.
|
||||
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
|
||||
"""
|
||||
super().__init__()
|
||||
if getattr(cfg, "n_obs_steps", 1) != 1:
|
||||
@@ -109,6 +109,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
self.image_normalizer = transforms.Normalize(
|
||||
mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
|
||||
)
|
||||
backbone_model = getattr(torchvision.models, cfg.backbone)(
|
||||
replace_stride_with_dilation=[False, False, cfg.dilation],
|
||||
pretrained=cfg.pretrained_backbone,
|
||||
@@ -275,9 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
return info
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
|
||||
# TODO(now): Maybe this shouldn't be here?
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
images = normalize(batch["observation.images.top"])
|
||||
images = self.image_normalizer(batch["observation.images.top"])
|
||||
|
||||
if return_loss: # training time
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
|
||||
|
||||
Reference in New Issue
Block a user