Commit b999102
Set a default for num_layers based on the model type - num_layers=2 seems good for the character classifier
AngledLuffa committed Sep 12, 2024
1 parent 7eb3a50 commit b999102
Showing 1 changed file with 7 additions and 1 deletion.
stanza/models/mwt_expander.py: 7 additions & 1 deletion
@@ -53,7 +53,7 @@ def build_argparse():

     parser.add_argument('--hidden_dim', type=int, default=100)
     parser.add_argument('--emb_dim', type=int, default=50)
-    parser.add_argument('--num_layers', type=int, default=1)
+    parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier')
     parser.add_argument('--emb_dropout', type=float, default=0.5)
     parser.add_argument('--dropout', type=float, default=0.5)
     parser.add_argument('--max_dec_len', type=int, default=50)
@@ -153,6 +153,12 @@ def train(args):
     args['vocab_size'] = vocab.size
     dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
 
+    if args['num_layers'] is None:
+        if args['force_exact_pieces']:
+            args['num_layers'] = 2
+        else:
+            args['num_layers'] = 1
+
     # train a dictionary-based MWT expander
     trainer = Trainer(args=args, vocab=vocab, device=args['device'])
     logger.info("Training dictionary-based MWT expander...")
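The commit follows a common argparse pattern: leave the command-line default as None, then fill in a model-type-specific default after parsing, once the model type is known. Below is a minimal standalone sketch of that pattern, not the stanza code itself; the flag name --force_exact_pieces is taken from the diff above, and treating it as a store_true boolean that selects the character classifier is an assumption here.

    # Standalone sketch of the deferred-default pattern used in this commit.
    # Assumption: --force_exact_pieces is a boolean flag choosing the
    # character classifier rather than the seq2seq expander.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_layers', type=int, default=None,
                        help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier')
    parser.add_argument('--force_exact_pieces', action='store_true')

    args = vars(parser.parse_args(['--force_exact_pieces']))
    if args['num_layers'] is None:
        # classifier gets 2 layers; the seq2seq expander keeps the old default of 1
        args['num_layers'] = 2 if args['force_exact_pieces'] else 1
    print(args['num_layers'])  # prints 2; without the flag it would print 1

Resolving the default this way keeps backward compatibility: anyone passing --num_layers explicitly is unaffected, and the old seq2seq behavior (1 layer) is unchanged unless the classifier path is taken.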
