backup wip
This commit is contained in:
@@ -26,10 +26,8 @@ class Transformer(nn.Module):
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
@@ -40,9 +38,7 @@ class Transformer(nn.Module):
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(
|
||||
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
|
||||
)
|
||||
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -57,7 +53,6 @@ class Transformer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask,
|
||||
query_embed,
|
||||
pos_embed,
|
||||
latent_input=None,
|
||||
@@ -68,10 +63,10 @@ class Transformer(nn.Module):
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
# Each "pixel" on the feature maps will form a token.
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
@@ -87,9 +82,9 @@ class Transformer(nn.Module):
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(1, 2)
|
||||
memory = self.encoder(src, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(0, 1)
|
||||
return hs
|
||||
|
||||
|
||||
@@ -103,14 +98,12 @@ class TransformerEncoder(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
|
||||
output = layer(output, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
@@ -119,52 +112,33 @@ class TransformerEncoder(nn.Module):
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
def __init__(self, decoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
return output
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
@@ -192,12 +166,10 @@ class TransformerEncoderLayer(nn.Module):
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src2 = self.self_attn(q, k, value=src)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
@@ -208,13 +180,11 @@ class TransformerEncoderLayer(nn.Module):
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src2 = self.self_attn(q, k, value=src2)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
@@ -224,13 +194,11 @@ class TransformerEncoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_pre(src, pos)
|
||||
return self.forward_post(src, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
@@ -262,23 +230,17 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt2 = self.self_attn(q, k, value=tgt)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
@@ -291,24 +253,18 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt2 = self.self_attn(q, k, value=tgt2)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
@@ -320,10 +276,6 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
@@ -331,16 +283,10 @@ class TransformerDecoderLayer(nn.Module):
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
|
||||
)
|
||||
return self.forward_post(tgt, memory, pos, query_pos)
|
||||
|
||||
|
||||
def _get_clones(module, n):
|
||||
@@ -356,7 +302,6 @@ def build_transformer(args):
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user