Skip to content

Commit

Permalink
update seq2seq model.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed May 10, 2022
1 parent 4f1fb5b commit 80ba702
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions examples/seq2seq_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def main():
)
parser.add_argument("--model_dir", default="output/bertseq2seq/", type=str, help="Dir for model save.")
parser.add_argument("--arch",
default="bertseq2seq", type=str,
default="convseq2seq", type=str,
help="The name of the task to train selected in the list: " + ", ".join(
['seq2seq', 'convseq2seq', 'bertseq2seq']),
)
Expand All @@ -52,7 +52,7 @@ def main():
parser.add_argument("--embed_size", default=128, type=int, help="Embedding size.")
parser.add_argument("--hidden_size", default=128, type=int, help="Hidden size.")
parser.add_argument("--dropout", default=0.25, type=float, help="Dropout rate.")
parser.add_argument("--epochs", default=10, type=int, help="Epoch num.")
parser.add_argument("--epochs", default=200, type=int, help="Epoch num.")

args = parser.parse_args()
print(args)
Expand Down
6 changes: 5 additions & 1 deletion pycorrector/seq2seq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ tensorboardX

## Demo

- bertseq2seq demo
- convseq2seq demo

示例[seq2seq_demo.py](../../examples/seq2seq_demo.py)
```
Expand Down Expand Up @@ -93,3 +93,7 @@ predict: 王天华开心地一直说话。
```
python preprocess.py
```

### release models

基于SIGHAN2015数据集训练的seq2seq和convseq2seq模型,已经release到github,通过[github models]()获取。
2 changes: 1 addition & 1 deletion pycorrector/seq2seq/seq2seq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def train(

if args.model_name and os.path.exists(args.model_name):
try:
# set global_step to gobal_step of last saved checkpoint from model path
# set global_step to global_step of last saved checkpoint from model path
checkpoint_suffix = args.model_name.split("/")[-1].split("-")
if len(checkpoint_suffix) > 2:
checkpoint_suffix = checkpoint_suffix[1]
Expand Down
2 changes: 1 addition & 1 deletion pycorrector/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
@author:XuMing([email protected])
@description: version
"""
__version__ = '0.4.4'
__version__ = '0.4.5'

0 comments on commit 80ba702

Please sign in to comment.