backup wip

This commit is contained in:
Alexander Soare
2024-04-03 14:21:07 +01:00
parent c7d70a8db9
commit 110ac5ffa1
6 changed files with 182 additions and 191 deletions

View File

@@ -26,10 +26,8 @@ class Transformer(nn.Module):
dropout=0.1,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
):
super().__init__()
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
@@ -40,9 +38,7 @@ class Transformer(nn.Module):
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
self._reset_parameters()
@@ -57,7 +53,6 @@ class Transformer(nn.Module):
def forward(
self,
src,
mask,
query_embed,
pos_embed,
latent_input=None,
@@ -68,10 +63,10 @@ class Transformer(nn.Module):
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
# Each "pixel" on the feature maps will form a token.
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
@@ -87,9 +82,9 @@ class Transformer(nn.Module):
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
hs = hs.transpose(1, 2)
memory = self.encoder(src, pos=pos_embed)
hs = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
hs = hs.transpose(0, 1)
return hs
@@ -103,14 +98,12 @@ class TransformerEncoder(nn.Module):
def forward(
self,
src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
output = layer(output, pos=pos)
if self.norm is not None:
output = self.norm(output)
@@ -119,52 +112,33 @@ class TransformerEncoder(nn.Module):
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
def __init__(self, decoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
return output
class TransformerEncoderLayer(nn.Module):
@@ -192,12 +166,10 @@ class TransformerEncoderLayer(nn.Module):
def forward_post(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src2 = self.self_attn(q, k, value=src)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
@@ -208,13 +180,11 @@ class TransformerEncoderLayer(nn.Module):
def forward_pre(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src2 = self.self_attn(q, k, value=src2)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
@@ -224,13 +194,11 @@ class TransformerEncoderLayer(nn.Module):
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
return self.forward_pre(src, pos)
return self.forward_post(src, pos)
class TransformerDecoderLayer(nn.Module):
@@ -262,23 +230,17 @@ class TransformerDecoderLayer(nn.Module):
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt2 = self.self_attn(q, k, value=tgt)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
@@ -291,24 +253,18 @@ class TransformerDecoderLayer(nn.Module):
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt2 = self.self_attn(q, k, value=tgt2)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
@@ -320,10 +276,6 @@ class TransformerDecoderLayer(nn.Module):
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
@@ -331,16 +283,10 @@ class TransformerDecoderLayer(nn.Module):
return self.forward_pre(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
return self.forward_post(
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
)
return self.forward_post(tgt, memory, pos, query_pos)
def _get_clones(module, n):
@@ -356,7 +302,6 @@ def build_transformer(args):
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
)