Skip to content

Commit

Permalink
simpler subsequent mask generator (#1198)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoosierEE authored Nov 27, 2023
1 parent c4dc481 commit c0b889d
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions word_language_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
self.init_weights()

def _generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
return torch.log(torch.tril(torch.ones(sz,sz)))

def init_weights(self):
initrange = 0.1
Expand Down

0 comments on commit c0b889d

Please sign in to comment.