某新闻数据集,第一个空格分割标签和文本,多标签再用|
符号分割。形如:
人生-死亡|灾害/意外-坍/垮塌|人生-失联 青岛地铁施工段坍塌致3人遇难2名失联者仍在搜救
├── README.md
├── chinese_roformer-v2-char_L-12_H-768_A-12 12层bert base
│ ├── bert_config.json
│ ├── bert_model.ckpt.data-00000-of-00001
│ ├── bert_model.ckpt.index
│ ├── bert_model.ckpt.meta
│ ├── checkpoint
│ └── vocab.txt
├── chinese_roformer-v2-char_L-6_H-384_A-6 6层bert small
│ ├── bert_config.json
│ ├── bert_model.ckpt.data-00000-of-00001
│ ├── bert_model.ckpt.index
│ ├── bert_model.ckpt.meta
│ ├── checkpoint
│ └── vocab.txt
├── config.py 部分超参数配置文件
├── data 数据集
│ ├── multi-classification-test.txt
│ └── multi-classification-train.txt
├── evaluate.py 模型评估
├── loader.py 数据编码器
├── main.py
├── model.jpg 模型示意图
├── model.py 模型文件
├── nohup.out 训练日志
├── path.py 路径文件
├── predict.py 模型预测
├── train.py 模型训练
├── utils bert4keras工具包,也可pip下载
│ ├── __init__.py
│ ├── adversarial.py
│ ├── backend.py
│ ├── layers.py
│ ├── loss.py
│ ├── models.py
│ ├── optimizers.py
│ ├── snippets.py
│ └── tokenizers.py
└── weights
├── mlb.pkl 所有标签
└── multi-label_roformer_v2_AdamEMA.h5 模型权重
5 directories, 36 files
Keras==2.2.4
matplotlib==3.4.0
pandas==1.2.3
pydot==1.4.1
tensorflow==1.14.0
tqdm==4.61.2
RoFormer-Sim,又称SimBERTv2,是SimBERT模型的升级版。
https://github.com/ZhuiyiTechnology/roformer-sim
基于UniLM思想、融检索与生成于一体的BERT模型。
单标签模型变多标签模型最后一层激活函数应由 softmax 变为 sigmoid
- 使用EMA(exponential mobing average)滑动平均配合Adam作为优化策略。滑动平均可以用来估计变量的局部值,是的变量的更新与一段时间内的历史值有关。它的意义在于利用滑动平均的参数来提高模型在测试数据上的健壮性。 EMA 对每一个待更新训练学习的变量 (variable) 都会维护一个影子变量 (shadow variable)。影子变量的初始值就是这个变量的初始值。
- BERT模型由于已经有了预训练权重,所以微调权重只需要很小的学习率,而Dense层使用的
he_normal
初始化学习率,需要使用较大学习率,所以本模型使用分层学习率 - 在Embedding层注入扰动,对抗训练 ,使模型更具鲁棒性。
187/187 [==============================] - 351s 2s/step - loss: 0.6388 - acc: 0.6758
1498it [00:29, 51.09it/s]
acc: 0.7790 best acc: 0.7790
Epoch 2/999
187/187 [==============================] - 363s 2s/step - loss: 0.0986 - acc: 0.8808
1498it [00:23, 63.30it/s]
acc: 0.8765 best acc: 0.8765
Epoch 3/999
187/187 [==============================] - 360s 2s/step - loss: 0.0729 - acc: 0.9018
1498it [00:27, 55.34it/s]
acc: 0.9005 best acc: 0.9005
... ...
Epoch 21/999
187/187 [==============================] - 359s 2s/step - loss: 0.0163 - acc: 0.9449
1498it [00:26, 57.33it/s]
Early stop count 3/5
acc: 0.9246 best acc: 0.9306
Epoch 22/999
187/187 [==============================] - 358s 2s/step - loss: 0.0172 - acc: 0.9441
1498it [00:26, 56.16it/s]
Early stop count 4/5
acc: 0.9259 best acc: 0.9306
Epoch 23/999
187/187 [==============================] - 352s 2s/step - loss: 0.0156 - acc: 0.9493
1498it [00:27, 54.31it/s]
Early stop count 5/5
acc: 0.9252 best acc: 0.9306
运行evaluate.py
文件
评估F1值,当且仅当单条样本的类别全部预测正确才算True,否则为False。读者可以后期自行修改判断标准,可以增设类别的权重然后算加权。
precision recall f1-score support
交往-会见 0.9231 1.0000 0.9600 12
交往-感谢 1.0000 1.0000 1.0000 8
交往-探班 1.0000 0.9000 0.9474 10
交往-点赞 0.9091 0.9091 0.9091 11
交往-道歉 0.9048 1.0000 0.9500 19
产品行为-上映 0.9706 0.9429 0.9565 35
产品行为-下架 1.0000 1.0000 1.0000 24
产品行为-发布 0.9933 0.9867 0.9900 150
产品行为-召回 1.0000 1.0000 1.0000 36
产品行为-获奖 0.9375 0.9375 0.9375 16
人生-产子/女 0.9375 1.0000 0.9677 15
人生-出轨 0.7500 0.7500 0.7500 4
人生-分手 1.0000 1.0000 1.0000 15
人生-失联 0.9286 0.9286 0.9286 14
人生-婚礼 0.8333 0.8333 0.8333 6
人生-庆生 1.0000 1.0000 1.0000 16
人生-怀孕 1.0000 0.7500 0.8571 8
人生-死亡 0.9495 0.8868 0.9171 106
人生-求婚 1.0000 1.0000 1.0000 9
人生-离婚 0.9706 1.0000 0.9851 33
人生-结婚 0.9500 0.8837 0.9157 43
人生-订婚 0.8182 1.0000 0.9000 9
司法行为-举报 1.0000 1.0000 1.0000 12
司法行为-入狱 0.9000 1.0000 0.9474 18
司法行为-开庭 0.9333 1.0000 0.9655 14
司法行为-拘捕 0.9775 0.9886 0.9831 88
司法行为-立案 1.0000 1.0000 1.0000 9
司法行为-约谈 0.9697 1.0000 0.9846 32
司法行为-罚款 1.0000 1.0000 1.0000 29
司法行为-起诉 0.9500 0.9048 0.9268 21
灾害/意外-地震 1.0000 1.0000 1.0000 14
灾害/意外-坍/垮塌 1.0000 1.0000 1.0000 10
灾害/意外-坠机 1.0000 1.0000 1.0000 13
灾害/意外-洪灾 1.0000 0.8571 0.9231 7
灾害/意外-爆炸 1.0000 0.8889 0.9412 9
灾害/意外-袭击 1.0000 0.8750 0.9333 16
灾害/意外-起火 0.9630 0.9630 0.9630 27
灾害/意外-车祸 0.9444 0.9714 0.9577 35
竞赛行为-夺冠 0.8833 0.9464 0.9138 56
竞赛行为-晋级 0.8919 1.0000 0.9429 33
竞赛行为-禁赛 0.9375 0.9375 0.9375 16
竞赛行为-胜负 0.9769 0.9906 0.9837 213
竞赛行为-退役 0.8462 1.0000 0.9167 11
竞赛行为-退赛 0.8571 1.0000 0.9231 18
组织关系-停职 0.9167 1.0000 0.9565 11
组织关系-加盟 1.0000 0.9268 0.9620 41
组织关系-裁员 1.0000 0.8947 0.9444 19
组织关系-解散 0.9091 1.0000 0.9524 10
组织关系-解约 0.8333 1.0000 0.9091 5
组织关系-解雇 0.8462 0.8462 0.8462 13
组织关系-辞/离职 1.0000 0.9859 0.9929 71
组织关系-退出 0.9048 0.8636 0.8837 22
组织行为-开幕 0.9375 0.9375 0.9375 32
组织行为-游行 1.0000 1.0000 1.0000 9
组织行为-罢工 1.0000 1.0000 1.0000 8
组织行为-闭幕 1.0000 1.0000 1.0000 9
财经/交易-上市 0.8750 1.0000 0.9333 7
财经/交易-出售/收购 1.0000 1.0000 1.0000 24
财经/交易-加息 1.0000 1.0000 1.0000 3
财经/交易-涨价 0.8000 0.8000 0.8000 5
财经/交易-涨停 0.9643 1.0000 0.9818 27
财经/交易-融资 1.0000 1.0000 1.0000 14
财经/交易-跌停 1.0000 1.0000 1.0000 14
财经/交易-降价 0.9000 1.0000 0.9474 9
财经/交易-降息 1.0000 1.0000 1.0000 4
micro avg 0.9604 0.9650 0.9627 1657
macro avg 0.9491 0.9583 0.9522 1657
weighted avg 0.9618 0.9650 0.9626 1657
samples avg 0.9720 0.9760 0.9693 1657
运行predict.py
文件
设置阈值,当sigmoid结果大于阈值时,判定为1,否则为0
one_hot, label = predict("xxxx")
测试样本预测标签: 灾害/意外-爆炸