diff --git a/word_language_model/model.py b/word_language_model/model.py
index 8023ac430f..1b972abb91 100644
--- a/word_language_model/model.py
+++ b/word_language_model/model.py
@@ -104,22 +104,16 @@ def forward(self, x):
         x = x + self.pe[:x.size(0), :]
         return self.dropout(x)
 
-class TransformerModel(nn.Module):
+class TransformerModel(nn.Transformer):
     """Container module with an encoder, a recurrent or transformer module, and a decoder."""
 
     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
-        super(TransformerModel, self).__init__()
-        try:
-            from torch.nn import TransformerEncoder, TransformerEncoderLayer
-        except BaseException as e:
-            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or '
-                              'lower.') from e
+        super(TransformerModel, self).__init__(d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_encoder_layers=nlayers)
         self.model_type = 'Transformer'
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
-        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
-        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
-        self.encoder = nn.Embedding(ntoken, ninp)
+
+        self.input_emb = nn.Embedding(ntoken, ninp)
         self.ninp = ninp
         self.decoder = nn.Linear(ninp, ntoken)
 
@@ -132,7 +126,7 @@ def _generate_square_subsequent_mask(self, sz):
 
     def init_weights(self):
         initrange = 0.1
-        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
+        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
         nn.init.zeros_(self.decoder.bias)
         nn.init.uniform_(self.decoder.weight, -initrange, initrange)
 
@@ -145,8 +139,8 @@ def forward(self, src, has_mask=True):
         else:
             self.src_mask = None
 
-        src = self.encoder(src) * math.sqrt(self.ninp)
+        src = self.input_emb(src) * math.sqrt(self.ninp)
         src = self.pos_encoder(src)
-        output = self.transformer_encoder(src, self.src_mask)
+        output = self.encoder(src, mask=self.src_mask)
         output = self.decoder(output)
         return F.log_softmax(output, dim=-1)
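
A minimal usage sketch (not part of the patch), assuming word_language_model/model.py with this change applied is importable; the hyperparameter values and variable names below are illustrative only. It exercises the refactored class: after the change, self.encoder is the nn.TransformerEncoder that nn.Transformer.__init__ builds, self.decoder is rebound to the output nn.Linear projection, and the token embedding lives in self.input_emb, so the forward pass behaves as before.

# Usage sketch (illustrative, not part of the patch).
import torch
from model import TransformerModel  # assumes the patched word_language_model/model.py

# Hypothetical hyperparameters: vocab size, embedding dim, heads, FFN dim, encoder layers.
ntoken, ninp, nhead, nhid, nlayers = 10000, 200, 2, 200, 2
model = TransformerModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.2)

# The example feeds sequence-first batches of token indices: (seq_len, batch).
src = torch.randint(0, ntoken, (35, 20))
log_probs = model(src)  # (seq_len, batch, ntoken) log-probabilities

# self.encoder here is the inherited nn.TransformerEncoder, while self.decoder
# has been overwritten with the output nn.Linear(ninp, ntoken) projection.
print(log_probs.shape)  # torch.Size([35, 20, 10000])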