基于Transformers+CNN/LSTM/GRU的文本分类
点击上方,选择星标或置顶,每天给你送干货
!
阅读大概需要5分钟跟随小博主,每天进步一丢丢
Transformers_for_Text_Classification
本项目的Github地址
https://github.com/zhanlaoban/Transformers_for_Text_Classification
基于Transformers的文本分类
基于最新的 huggingface 出品的 transformers v2.2.2代码进行重构。为了保证代码日后可以直接复现而不出现兼容性问题,这里将 transformers 放在本地进行调用。
Highlights
支持transformer模型后接各种特征提取器 支持测试集预测代码 精简原始transformers代码,使之更适合文本分类任务 优化logging终端输出,使之输出内容更加合理
Support
model_type:
[✔] bert [✔] bert+cnn [✔] bert+lstm [✔] bert+gru [✔] xlnet [ ] xlnet+cnn [✔] xlnet+lstm [✔] xlnet+gru [ ] albert
Content
dataset:存放数据集 pretrained_models:存放预训练模型 transformers:transformers文件夹 results:存放训练结果
Usage
1. 使用不同模型
在shell文件中修改
model_type
参数即可指定模型如,BERT后接FC全连接层,则直接设置
model_type=bert
;BERT后接CNN卷积层,则设置model_type=bert_cnn
.在本README的
Support
中列出了本项目中各个预训练模型支持的model_type
。最后,在终端直接运行shell文件即可,如:
bash run_classifier.sh
注:在中文RoBERTa、ERNIE、BERT_wwm这三种预训练语言模型中,均使用BERT的model_type进行加载。
2. 使用自定义数据集
在 dataset
文件夹里存放自定义的数据集文件夹,如TestData
.在根目录下的 utils.py
中,仿照class THUNewsProcessor
写一个自己的类,如命名为class TestDataProcessor
,并在tasks_num_labels
,processors
,output_modes
三个dict中添加相应内容.最后,在你需要运行的shell文件中修改TASK_NAME为你的任务名称,如 TestData
.
Environment
- one 2080Ti, 12GB RAM
- Python: 3.6.5
- PyTorch: 1.3.1
- TensorFlow: 1.14.0(仅为了支持TensorBoard,无其他作用)
- Numpy: 1.14.6
Performance
数据集: THUNews/5_5000
epoch:1
train_steps: 5000
model | dev set best F1 and Acc | remark |
---|---|---|
bert_base | 0.9308, 0.9324 | BERT接FC层, batch_size 8, learning_rate 2e-5 |
bert_base+cnn | 0.9136, 0.9156 | BERT接CNN层, batch_size 8, learning_rate 2e-5 |
bert_base+lstm | 0.9369, 0.9372 | BERT接LSTM层, batch_size 8, learning_rate 2e-5 |
bert_base+gru | 0.9379, 0.938 | BERT接GRU层, batch_size 8, learning_rate 2e-5 |
roberta_large | RoBERTa接FC层, batch_size 2, learning_rate 2e-5 | |
xlnet_large | 0.9530, 0.954 | XLNet接FC层, batch_size 2, learning_rate 2e-5 |
xlnet_mid+lstm | 0.9269, 0.9304 | XLNet接LSTM层, batch_size 2, learning_rate 2e-5 |
xlnet_mid+gru | 0.9494, 0.9508 | XLNet接GRU层, batch_size 2, learning_rate 2e-5 |
albert_xlarge_183k |
Download Chinese Pre-trained Models
NPL_PEMDC(https://github.com/zhanlaoban/NLP_PEMDC)
欢迎大家踊跃投稿!
最新评论
推荐文章
作者最新文章
你可能感兴趣的文章
Copyright Disclaimer: The copyright of contents (including texts, images, videos and audios) posted above belong to the User who shared or the third-party website which the User shared from. If you found your copyright have been infringed, please send a DMCA takedown notice to [email protected]. For more detail of the source, please click on the button "Read Original Post" below. For other communications, please send to [email protected].
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。