本文分享自华为云社区《TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用》,作者:汀丶。
TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包,融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。
TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。
主要特点:
TextBrewer目前支持的知识蒸馏技术有:
TextBrewer的主要功能与模块分为3块:
用户需要准备:
在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果。
pip install textbrewer
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer
Stage 1 : 蒸馏之前的准备工作:
Stage 2 : 使用TextBrewer蒸馏:
在开始蒸馏之前准备:
使用TextBrewer蒸馏:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
#展示模型参数量的统计
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)
print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)
#定义adaptor用于解释模型的输出
def simple_adaptor(batch, model_outputs):
return {'logits': model_outputs[1], 'hidden': model_outputs[2]}
#蒸馏与训练配置
distill_config = DistillationConfig(
intermediate_matches=[
{'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},
{'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig()
#初始化distiller
distiller = GeneralDistiller(
train_config=train_config, distill_config = distill_config,
model_T = teacher_model, model_S = student_model,
adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)
#开始蒸馏
with distiller:
distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)
Transformers 4示例
examples/random_token_example: 一个可运行的简单示例,在文本分类任务上以随机文本为输入,演示TextBrewer用法。
examples/cmrc2018_example (中文): CMRC 2018上的中文阅读理解任务蒸馏,并使用DRCD数据集做数据增强。
examples/mnli_example (英文): MNLI任务上的英文句对分类任务蒸馏,并展示如何使用多教师蒸馏。
examples/conll2003_example (英文): CoNLL-2003英文实体识别任务上的序列标注任务蒸馏。
examples/msra_ner_example (中文): MSRA NER(中文命名实体识别)任务上,使用分布式数据并行训练的Chinese-ELECTRA-base模型蒸馏。
我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。
我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。
Model
#Layers
Hidden size
Feed-forward size
#Params
Relative size
BERT-base-cased (教师)
12
768
3072
108M
100%
T6 (学生)
6
768
3072
65M
60%
T3 (学生)
3
768
3072
44M
41%
T3-small (学生)
3
384
1536
17M
16%
T4-Tiny (学生)
4
312
1200
14M
13%
T12-nano (学生)
12
256
1024
17M
16%
BiGRU (学生)
-
768
-
31M
29%
Model
#Layers
Hidden size
Feed-forward size
#Params
Relative size
RoBERTa-wwm-ext (教师)
12
768
3072
102M
100%
Electra-base (教师)
12
768
3072
102M
100%
T3 (学生)
3
768
3072
38M
37%
T3-small (学生)
3
384
1536
14M
14%
T4-Tiny (学生)
4
312
1200
11M
11%
Electra-small (学生)
12
256
1024
12M
12%
distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)
#其他参数为默认值
不同的模型用的matches我们采用了以下配置:
Model
matches
BiGRU
None
T6
L6_hidden_mse + L6_hidden_smmd
T3
L3_hidden_mse + L3_hidden_smmd
T3-small
L3n_hidden_mse + L3_hidden_smmd
T4-Tiny
L4t_hidden_mse + L4_hidden_smmd
T12-nano
small_hidden_mse + small_hidden_smmd
Electra-small
small_hidden_mse + small_hidden_smmd
各种matches的定义在examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。
蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。
在英文实验中,我们使用了如下三个典型数据集。
Dataset
Task type
Metrics
#Train
#Dev
Note
MNLI
文本分类
m/mm Acc
393K
20K
句对三分类任务
SQuAD 1.1
阅读理解
EM/F1
88K
11K
篇章片段抽取型阅读理解
CoNLL-2003
序列标注
F1
23K
6K
命名实体识别任务
我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。
Public results:
Model (public)
MNLI
SQuAD
CoNLL-2003
DistilBERT (T6)
81.6 / 81.1
78.1 / 86.2
-
BERT6-PKD (T6)
81.5 / 81.0
77.1 / 85.3
-
BERT-of-Theseus (T6)
82.4/ 82.1
-
-
BERT3-PKD (T3)
76.7 / 76.3
-
-
TinyBERT (T4-tiny)
82.8 / 82.9
72.7 / 82.1
-
Our results:
Model (ours)
MNLI
SQuAD
CoNLL-2003
BERT-base-cased (教师)
83.7 / 84.0
81.5 / 88.6
91.1
BiGRU
-
-
85.3
T6
83.5 / 84.0
80.8 / 88.1
90.7
T3
81.8 / 82.7
76.4 / 84.9
87.5
T3-small
81.3 / 81.7
72.3 / 81.4
78.6
T4-tiny
82.0 / 82.6
75.2 / 84.0
89.1
T12-nano
83.2 / 83.9
79.0 / 86.6
89.6
说明:
在中文实验中,我们使用了如下典型数据集。
Dataset
Task type
Metrics
#Train
#Dev
Note
文本分类
Acc
393K
2.5K
MNLI的中文翻译版本,3分类任务
LCQMC
文本分类
Acc
239K
8.8K
句对二分类任务,判断两个句子的语义是否相同
阅读理解
EM/F1
10K
3.4K
篇章片段抽取型阅读理解
阅读理解
EM/F1
27K
3.5K
繁体中文篇章片段抽取型阅读理解
MSRA NER
序列标注
F1
45K
3.4K (测试集)
中文命名实体识别
实验结果如下表所示。
Model
XNLI
LCQMC
CMRC 2018
DRCD
RoBERTa-wwm-ext (教师)
79.9
89.4
68.8 / 86.4
86.5 / 92.5
T3
78.4
89.0
66.4 / 84.2
78.2 / 86.4
T3-small
76.0
88.1
58.0 / 79.3
75.8 / 84.8
T4-tiny
76.2
88.4
61.8 / 81.8
77.3 / 86.1
Model
XNLI
LCQMC
CMRC 2018
DRCD
MSRA NER
Electra-base (教师)
77.8
89.8
65.6 / 84.7
86.9 / 92.3
95.14
Electra-small
77.7
89.3
66.5 / 84.9
85.5 / 91.3
93.48
说明:
Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:
蒸馏实验中,有两个组件需要由用户提供,分别是callback 和 adaptor :
回调函数。在每个checkpoint,保存模型后会被distiller调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。
将模型的输入和输出转换为指定的格式,向distiller解释模型的输入和输出,以便distiller根据不同的策略进行不同的计算。在每个训练步,batch和模型的输出model_outputs会作为参数传递给adaptor,adaptor负责重新组织这些数据,返回一个字典。
更多细节可参见完整文档中的说明。
Q: 学生模型该如何初始化?
A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重(中文任务)或BERT的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。
Q: 如何设置蒸馏的训练参数以达到一个较好的效果?
A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base上训练SQuAD一般以lr=3e-5训练3轮左右即可达到较好的效果;而蒸馏时需要以lr=1e-4训练30~50轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考。
Q: 我的教师模型和学生模型的输入不同(比如词表不同导致input_ids不兼容),该如何进行蒸馏?
A: 需要分别为教师模型和学生模型提供不同的batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?
A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
手机扫一扫
移动阅读更方便
你可能感兴趣的文章