Update TransformerModel using nn.Transformer module (#1138)
tairenpiao authored Aug 8, 2023
1 parent 2095514 commit 001d493
Showing 1 changed file with 7 additions and 13 deletions.
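Summary of the change: TransformerModel now subclasses nn.Transformer instead of nn.Module, so the encoder stack is built by the parent constructor rather than assembled by hand from TransformerEncoderLayer and TransformerEncoder; the token embedding is renamed from self.encoder to self.input_emb, freeing the encoder name for the TransformerEncoder that nn.Transformer exposes; and the PyTorch 1.1 import guard is removed.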
word_language_model/model.py (20 changes: 7 additions & 13 deletions)
@@ -104,22 +104,16 @@ def forward(self, x):
         x = x + self.pe[:x.size(0), :]
         return self.dropout(x)
 
-class TransformerModel(nn.Module):
+class TransformerModel(nn.Transformer):
     """Container module with an encoder, a recurrent or transformer module, and a decoder."""
 
     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
-        super(TransformerModel, self).__init__()
-        try:
-            from torch.nn import TransformerEncoder, TransformerEncoderLayer
-        except BaseException as e:
-            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or '
-                              'lower.') from e
+        super(TransformerModel, self).__init__(d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_encoder_layers=nlayers)
         self.model_type = 'Transformer'
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
-        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
-        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
-        self.encoder = nn.Embedding(ntoken, ninp)
+
+        self.input_emb = nn.Embedding(ntoken, ninp)
         self.ninp = ninp
         self.decoder = nn.Linear(ninp, ntoken)
 
@@ -132,7 +126,7 @@ def _generate_square_subsequent_mask(self, sz):
 
     def init_weights(self):
         initrange = 0.1
-        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
+        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
         nn.init.zeros_(self.decoder.bias)
         nn.init.uniform_(self.decoder.weight, -initrange, initrange)
 
@@ -145,8 +139,8 @@ def forward(self, src, has_mask=True):
         else:
             self.src_mask = None
 
-        src = self.encoder(src) * math.sqrt(self.ninp)
+        src = self.input_emb(src) * math.sqrt(self.ninp)
         src = self.pos_encoder(src)
-        output = self.transformer_encoder(src, self.src_mask)
+        output = self.encoder(src, mask=self.src_mask)
         output = self.decoder(output)
         return F.log_softmax(output, dim=-1)
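
A minimal usage sketch of the updated class (not part of the commit; the sizes below are hypothetical, and PositionalEncoding is the class defined earlier in word_language_model/model.py):

import torch
from model import TransformerModel  # assumes we run from word_language_model/

# Hypothetical sizes; ninp must be divisible by nhead.
ntokens, seq_len, bsz = 10000, 35, 20
model = TransformerModel(ntoken=ntokens, ninp=200, nhead=2,
                         nhid=200, nlayers=2, dropout=0.2)

# nn.Transformer defaults to batch_first=False, so input is (seq_len, batch).
src = torch.randint(0, ntokens, (seq_len, bsz))
out = model(src)  # (seq_len, batch, ntokens) log-probabilities from log_softmax

One side effect worth noting: the dropout argument is no longer forwarded to super().__init__, so after this change it only affects the PositionalEncoding; the encoder layers created by nn.Transformer keep that module's own default dropout (0.1).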
