NLP —— 模型优化蒸馏案例 目录一、概念二、主流四大类技术1. 模型量化2. 模型剪枝3. 低秩因式分解4. 模型蒸馏三、代码案例需求代码思路① Config文件② 教师模型文件③ 学生模型文件1 定义参数2 搭建网络层3 前向传播④ 数据预处理文件1 读取文件数据处理2 自定义数据集类3 数据二次处理 - 数据张量和掩码张量4 构造数据加载器⑤ 模型蒸馏训练1 创建数据加载器对象2 创建教师模型对象 加载已训练好的模型3 创建学生模型对象4 损失函数5 优化器6 变量训练轮次、初始化f1_score蒸馏温度T、α系数7 设置老师模型评估模式、学生模型训练模式8 训练⑥模型预测使用一、概念模型压缩在尽量不损失精度前提下减小模型参数量、显存占用、推理耗时方便部署 CPU / 移动端。目标 参数变少、模型文件变小、推理更快、显存更低。 常见落地大 BERT→小 BiLSTM二、主流四大类技术1. 模型量化pytorch中默认 float32 int64. - float16 int8 。降低精度。从而缩减模型并加速推断速度。。pytorch 中 Quantization官网API 静态、动态APIQuantization — PyTorch 2.4 documentation① 训练中量化 QAT 量化感知训练② 训练后量化1 动态量化 DQNLP领域2 静态两会 QTQ CV领域特性静态量化动态量化APIpreparequantize_dynamic适用模型CNNResNet, MobileNetNLP模型BERT, LSTM等PyTorch的动态量化只能在CPU上执行核心代码# 定义一个模型 class Model(torch.nn.Module): def __init__(self): super().__init__() self.embedded nn.Embedding(4, 128) self.rnn nn.GRU(128, 1024, batch_firstTrue) self.linear nn.Linear(1024, 10) self.dropout nn.Dropout(p0.1) def forward(self, x): x, hn self.rnn(self.embedded(x)) return self.dropout(self.linear(x))# 创建量化模型实例 # model原始模型 # qconfig_spec待量化的层参数 # dtype量化权重的目标类型 model2 torch.quantization.quantize_dynamic(modelmodel1, qconfig_spec{torch.nn.Linear, nn.GRU}, dtypetorch.qint8)2. 模型剪枝NLP中不用一般在CV中用。Pytorch中对模型剪枝的支持在torch.nn.utils.prune模块中, 分以下几种剪枝方式随机剪枝L1结构化剪枝L1非结构化剪枝全局非结构化剪枝非结构化剪枝结构化剪枝按单个权重裁剪按神经元、通道、整行/列裁剪剪枝后是稀疏矩阵剪枝后是稠密矩阵类似于裁掉部门中贡献度低的个人类似于裁掉整个部门代码# 演示随机非结构化剪枝 def dm01(): linear nn.Linear(2, 3) print(linear--, linear.weight) model prune.random_unstructured(linear, weight, amount2) print(model--, model.weight) # 演示全局非结构化剪枝 def dm02(): net nn.Sequential(OrderedDict([ (first, nn.Linear(3, 4)), (second, nn.Linear(4, 2)), ])) print(net1--, net) for model in net: print(model--, model.weight) parameters_to_prune ((net.first, weight), (net.second, weight)) # parameters_to_prune待剪枝的参数 # pruning_method剪枝的方式L1Unstructured表示非结构化剪枝常用 # amount如果是小数则表示比例如果是整数则表示数量 prune.global_unstructured(parameters_to_prune, pruning_methodprune.L1Unstructured, amount 0.2) print(net2--, net) for model in net: print(model--, model.weight)3. 低秩因式分解比如21128词表 * 768维度 很大进行分解。运用矩阵分解减少网络参数量提升效率。4. 模型蒸馏复杂模型教师模型- 简单模型学生模型教师模型定义复杂的、高性能的模型通常是大型深度神经网络。特点参数量大能够学习复杂的特征和关系。需要提前训练好。学生模型定义简化的、小型的模型可以是教师模型的子集或者简单模型。特点参数量较小适用于资源受限的场景。不需要提前训练好。知识的来源硬标签蒸馏学生模型直接学习教师模型的分类结果。软标签蒸馏学生模型学习教师模型对每个类别的概率分布。中间层蒸馏学生模型学习教师模型的隐藏层、特征图等。关键点高温T平滑输出概率生成软标签效果BERT (110M 参数) → BiLSTM (几 M 参数)体积压缩十几倍损失 真实标签 CE 损失 KL 蒸馏损失适用NLP 分类、文本任务。公式# 计算KL散度值 p torch.log_softmax(teacher_pred/T, dim-1) q torch.log_softmax(student_pred/T, dim-1) # KL散度值也就是软标签的值 参数解释 input是【学生模型】输出的结果 target预测结果参考值。也就是【教师模型】输出的结果 reduction上面两个值的计算方式。 log_target是否对计算结果求log对数 kl_value torch.nn.functional.kl_div( inputq, targetp, reductionbatchmean, log_targetTrue ) # 硬标签损失值 # 注意是学生模型的预测概率与样本的目标值算损失 hard_label_loss loss(student_pred,labels) # 蒸馏的总损失值 # l (1-α) * 硬标签损失值 α * T² * KL散度值 distll_loss (1 - alpha) * hard_label_loss alpha * (T**2) * kl_valueq: 学生模型预测结果计算得来p: 教师模型预测结果计算得来CEy,p也就是 学生模型自己的交叉熵损失参数α系数控制从学生模型和教师模型学习的比例比如α0.8。参数T蒸馏温度是一个平滑系数控制softmax的输出比如T4。蒸馏总损失值 L_{KD} (1 - α)CE(y,p) αKL(q,p)KLDivLoss — PyTorch 2.4 documentation三、代码案例需求以文本分类任务基于Bert模型的 教师模型学生模型内部使用BiLstm神经网络数据文本 ( 内容 类别索引 )数据源三个内容文件一个类别文件。代码思路① Config文件配置各个文件路径数据源模型批次大小句子最大长度class Config(object): def __init__(self): # 1 - 设备 # self.device torch.device(cuda if torch.cuda.is_available() else cpu) self.device cpu # 2 原始文件 self.train_datapath data/train.txt self.test_datapath data/test.txt self.dev_datapath data/dev.txt self.class_datapath data/class.txt # 3 数据加载参数 self.batch_size 64 self.max_seq_len 32 # 4 Bert 预训练模型路径 self.bert_path ../Base_Bert_TMF/bert_base_model/bert-base-chinese # 5 - 目标值 文本解析 self.classname_list [line.strip() for line in open(self.class_datapath,moder,encodingutf-8)] self.classname_len len(self.classname_list) # 6 - 训练好的【教师模型】路径 self.teacher_model_path save_model/teacher_bert.pkl # 7 - 学生模型路径 self.student_model_path save_model/student_model.pkl② 教师模型文件基于Bert模型经过线性层处理冻结反向传播。已训练好的模型线性层in_features Bert模型的隐藏状态大小,out_features数据源类的总共个数 教师模型基于Bert模型 import torch import torch.nn as nn from transformers import BertModel from transformers import BertConfig from config import Config config Config() class TeacherBertModel(nn.Module): def __init__(self): super().__init__() self.bert_model BertModel.from_pretrained(config.bert_path) temp_config BertConfig.from_pretrained(config.bert_path) in_features temp_config.hidden_size self.linear nn.Linear( in_featuresin_features, out_featuresconfig.classname_len ) def forward(self, input_ids, attention_maskNone): # 教师模型不需要训练 要冻结反向传播 with torch.no_grad(): bert_output self.bert_model( input_idsinput_ids, attention_maskattention_mask ) # 2- 教师模型的池化层实际就是nn.Linear激活函数。不用额外定义 1- last_hidden_state[:,0]和pooler_output实际是类似的东西都表示[CLS]的隐藏状态。 区别需要对last_hidden_state[:,0]经过nn.Linear和激活函数处理后才能得到pooler_output 对应源代码位置BertModel文件的697行 2- 获得池化层后的结果有两种方式 2.1- 方式一推荐。通过实例属性获得 bert_output.pooler_output 2.2- 方式二通过实例属性索引获得 bert_output[1]。1的原因是pooler_output是类中的第2个实例属性 对应源代码位置BertModel文件的1017行 # 因为是句子 分类问题所以取句子的向量。 pooled_output bert_output.pooler_output return self.linear(pooled_output)③ 学生模型文件定义学生模型类1 定义参数词汇表大小词向量维度隐藏状态隐藏层层数2 搭建网络层词向量层、双向LSTM、随机失活层、线性层输入 2倍的隐藏大小输出 句子最大长度3 前向传播1 数据张量化2 输入原始数据处理过滤【CLS、SEP】特殊标识基于transformer系列都有这个标识。结合输入掩码张量对原始数据矩阵点乘处理得到最终有效的词张量数据3 调用BiLstm循环神经网络 - 得到输出数据【batch_size,seq_lenhidden_size】4 因为是文本分类需要的是句子对输出数据累加-降维-记得句向量数据5调用随机失活 线性层- 输出 学生模型 用BILSTM 双向模型 from torch import Tensor from config import Config import torch import torch.nn as nn from transformers import BertConfig config Config() bert_config BertConfig.from_pretrained(config.bert_path) class BILSTMStudentModel(nn.Module): def __init__(self): super().__init__() 设置参数 基于Bert模型的中文词汇表大小 self.vocab_size bert_config.vocab_size self.embedding_dim 128 self.hidden_size 256 self.num_layers 3 搭建网络层 embedding_dim由我们自己设置与教师模型没有任何关系 self.embedding nn.Embedding(self.vocab_size, self.embedding_dim) self.lstm nn.LSTM( input_sizeself.embedding_dim, #输入的词向量维度必须和embding_dim 相同 hidden_sizeself.hidden_size, #隐藏层向量维度 自定义 batch_firstTrue, #是否batch_size开头的张量 【batch_size,seq_len,hidden_size】 num_layersself.num_layers, #隐藏层层数 bidirectionalTrue #是否双向 ) self.dropout nn.Dropout(p0.2) 因为双向LSTM 所以 hidden_size*2 多分类任务任务值是 取数据类别个数 作为输出 self.linear nn.Linear(self.hidden_size*2, config.classname_len) def forward(self, input_ids, attention_mask): # 1 - 数据张量化 ebd self.embedding(input_ids) 带 【CLS、SEP】特殊标识 TokenBERT 系 Transformer 编码器网络 所以数据要先把 【CLS】、【SEP】标识去除 # 2 - cls_token_index 101 #句子开头 CLS固定索引值 sep_token_index 102 #句子结尾 SEP固定索引值 # 2.1 # 对 input_ids 数据过滤 CLS 和 SEP ebd_mask (input_ids ! cls_token_index) (input_ids ! sep_token_index) # 2.2 # 过滤后的数据 与 掩码进行再次过滤 得到实际要用的掩码 ebd_mask:Tensor ebd_mask attention_mask # 2.3 # 对 edb_mask 升维 # 原始【batch_size,seq_len】 - 【batch_size, seq_len, 1】 ebd_mask ebd_mask.unsqueeze(-1) # 2.4 # 原始数据 与 实际掩码 进行点乘预算得到实际有效的数据源 ebd ebd * ebd_mask # 3 - 调用循环神经网络BiLSTM # 为什么调用lstm的时候没有手动传递初始的细胞状态和隐藏状态LSTM内部会自动的进行全0初始化。源代码在1056行 out_put, (hidden, c) self.lstm(ebd) # 4 - 计算平均池化值 # 4.1 # 降维 以为是对词向量进行 网络处理需求做的是句子分类 # 【batch_sizeseq_len,hidden_size】 [batch_size, hidden_size] output_sum out_put.sum(dim1) # 4.2 # 获取所有有效词的个数 1e-6 为了防止个数为0 token_count ebd_mask.sum(dim1) 1e-6 # 4.3 # 计算获取 最终的句子向量数据 new_output output_sum / token_count # 5 # 调用线性层得到预测结构并返回 return self.linear(self.dropout(new_output))④ 数据预处理文件1 读取文件数据处理表格数据读取 - 得到数组 每行的数据2 自定义数据集类1 __init__ 参数定义 self.data_list 1处理得到的2 __len__ 样本条数3 __getitem__ 函数根据索引获得 对应的 文本和分类 值3 数据二次处理 - 数据张量和掩码张量1 传入每批次数据输入数据[(近期新盘推荐 通州纯新别墅本周开盘, 1), (陕西退休教师嫌弃精神病 女儿将其勒死被捕, 5)]输出数据[(近期新盘推荐 通州纯新别墅本周开盘, 陕西退休教师嫌弃精神病女儿 将其勒死被捕), (1, 5)]得到 文本内容元组 和 类别元组2 通过 transformers 的 BertTokenizer, 把数据转换为词索引张量3 返回 数据张量(intput_ids)、掩码张量(attention_mask)、真实类别张量(lables)4 构造数据加载器1 通过1、2、得到数据集2 创建数据加载器对象 DataLoader3 返回加载器对象 数据处理 得到模型需要的 input_dis 和 attention_mask. 并传递 真实值 Labels # 1 读取文件获得数据 # 2 定义数据集 # 3 数据二次处理 (按batch处理成input_disattention_mask 张量) # 4 构建数据加载器 import torch import torch.nn as nn from config import Config from torch.utils.data import Dataset,DataLoader from transformers import BertTokenizer config Config() bert_tokenizer BertTokenizer.from_pretrained(config.bert_path) # 1 - 数据获取处理 def load_data(datapath): with open(datapath,moder,encodingUTF-8) as f: lines f.readlines() result_list [] for line in lines: line line.strip() if line: continue # 样本数据 # 两天价网站背后重重迷雾做个网站究竟要多少钱 4 title, label line.split(\t) # 【可选】健壮性代码 只要是有数据类型转换的地方基本都有健壮性代码 if not label.isdigit(): print(flabel的数据内容不合法值是{label}) continue # 保存数据 result_list.append((title,int(label))) return result_list # 2 - 自定义数据集 class NewsDataset(Dataset): def __init__(self,data_list): super().__init__() self.data_list data_list #读取数据 self.sample_len len(self.data_list) #样本条数 def __len__(self): return self.sample_len def __getitem__(self, idx): # 防止数组越界 index min(max(idx, 0),self.sample_len-1) title,label self.data_list[index] return title,label # 3 - 数据二次处理,按每批次数据处理 def collate_fn(batch_data): zip(*)处理过程如下 输入数据[(近期新盘推荐 通州纯新别墅本周开盘, 1), (陕西退休教师嫌弃精神病女儿将其勒死被捕, 5)] 输出数据[(近期新盘推荐 通州纯新别墅本周开盘, 陕西退休教师嫌弃精神病女儿将其勒死被捕), (1, 5)] titles,labels zip(*batch_data) # 根据词索引 数据张量化 - 获取词索引张量 title_tensor bert_tokenizer( titles, paddingmax_length, truncationTrue, max_lengthconfig.max_seq_len, return_tensorspt ) return ( title_tensor.input_ids, title_tensor.attention_mask, torch.tensor(labels,dtypetorch.long) ) # 4 - 构建数据加载器 def build_dataloader(datapath, shuffleTrue): data load_data(datapath) dataset NewsDataset(data) data_loader DataLoader( datasetdataset, batch_sizeconfig.batch_size, shuffleshuffle, collate_fncollate_fn ) return data_loader⑤ 模型蒸馏训练学生模型训练边训练边预测保存1 创建数据加载器对象2 创建教师模型对象 加载已训练好的模型3 创建学生模型对象4 损失函数5 优化器6 变量训练轮次、初始化f1_score蒸馏温度T、α系数7 设置老师模型评估模式、学生模型训练模式8 训练8.1 根据数据加载器分批次 获取输入张量、掩码张量、真实类别张量8.2 模型前向传播其中老师模型冻结不需要更新8.3 计算KL散度8.4 计算学生模型交叉熵损失值8.5 计算蒸馏总损失值8.6 梯度清零、反向传播、梯度更新8.7 每固定间隔 对学生模型进行评估1 数据加载器加载评估数据2 学生模型切换评估模式3 数据加载器分批次进行模型评估保存真实结果和评估结果4 计算评估指标f1_score、accuracy(准确率)、precision(精确率)、recall(召回率)8.8 f1_socre 上一次的f1_socre 值保存模型进行覆盖。8.9 学生模型切换训练模型继续训练直到所有训练数据结束 模型蒸馏 import torch import torch.nn as nn from tqdm import tqdm from data_preprocessing import build_dataloader from student_bilstm_model import BILSTMStudentModel from teacher_bert_model import TeacherBertModel from config import Config from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score config Config() def eval(student_model): # 1. 数据加载器 dataloader build_dataloader(config.dev_datapath, shuffleFalse) # 2. 切换模式 student_model.eval() all_pred_result [] # 预测结果列表 all_true_result [] # 真实结果列表 # 3. 预测 with torch.no_grad(): for batch_idx, batch_data in enumerate(tqdm(dataloader),start1): input_dis, attention_mask, labels batch_data input_dis input_dis.to(config.device) attention_mask attention_mask.to(config.device) labels labels.to(config.device) # 预测结果 student_pred student_model(input_dis, attention_mask) student_pred_index torch.argmax(student_pred, dim-1) # cpu()因为不涉及张量的计算因此为了节约GPU资源可以将数据转到CPU上再处理 # .tolist() tensor([0,2,1]) → [0,2,1] # .extend() # append([1,2,3]) → [[1,2,3]]嵌套列表 # extend([1,2,3]) → [1,2,3]把元素挨个拼进去 all_pred_result.extend(student_pred_index.cpu().tolist()) all_true_result.extend(labels.cpu().tolist()) # 4 - 计算评估指标 f1score f1_score(all_true_result,all_pred_result,averagemacro) # 准确率 accuracy accuracy_score(all_true_result,all_pred_result) precision precision_score(all_true_result,all_pred_result,averagemacro) recall recall_score(all_true_result,all_pred_result,averagemacro) return f1score, accuracy, precision, recall def train_and_eval(): # 1. 通过加载器获取数据 data_loader build_dataloader(config.train_datapath, shuffleTrue) # 2 - 教师模型 teacher_model TeacherBertModel().to(config.device) teacher_model.load_state_dict(torch.load(config.teacher_model_path)) # 3 - 学生模型 student_model BILSTMStudentModel().to(config.device) # 4 - 损失函数 loss_fn nn.CrossEntropyLoss() # 5 - 优化器 optimizer torch.optim.Adam(student_model.parameters(), lr5e-5) # 6 - 其他变量 epochs 1 best_f1score 0 T 2 #蒸馏温度 alpha 0.7 #计算蒸馏总损失 KL散度和学生 概率比例 # 7 - 训练模式 student_model.train() teacher_model.eval() # 8 训练 for epoch in range(epochs): for batch_idx, batch_data in enumerate(tqdm(data_loader),start1): input_dis, attention_mask, labels batch_data # 8.1 批次训练数据 # 输入张量、掩码张量、真实张量 input_dis input_dis.to(config.device) attention_mask attention_mask.to(config.device) labels labels.to(config.device) # 8.2 模型前向传播 # 老师模型冻结不需要更新 with torch.no_grad(): teacher_pred teacher_model(input_dis, attention_mask) teacher_pred_labels torch.argmax(teacher_pred, dim-1) student_pred student_model(input_dis, attention_mask) student_pred_labels torch.argmax(student_pred, dim-1) # 8.3 # 计算KL散度 p torch.log_softmax(teacher_pred/T, dim-1) q torch.log_softmax(student_pred/T, dim-1) # KL散度值也就是软标签的值 注意kl_div的包不要导错了 参数解释 input是【学生模型】输出的结果 target预测结果参考值。也就是【教师模型】输出的结果 reduction上面两个值的计算方式。 log_target是否对计算结果求log对数 kl_value torch.nn.functional.kl_div( inputq, targetp, reductionbatchmean, log_targetTrue ) # 8.4 学生模型自己的损失值 loss_value loss_fn(student_pred, labels) # 8.5 蒸馏总损失值 固定公式 distill_loss (1-alpha) * loss_value alpha * kl_value * (T**2) # 8.6 梯度清零反向传播梯度更新 optimizer.zero_grad() distill_loss.backward() optimizer.step() # 8.7 每间隔100个批次 或者 最后一个批次对学生模型进行验证 if batch_idx%1000 or batch_idxlen(data_loader): f1_score, accuracy, precision, recall eval(student_model) print(f第{batch_idx}批次f1score{f1_score}accuracy{accuracy}precision{precision}recall{recall}) if f1_score best_f1score: torch.save(student_model.state_dict(), config.student_model_path) best_f1score f1_score # 切换回训练模式 student_model.train() if __name__ __main__: train_and_eval()⑥模型预测使用 预测函数 提供模型服务 import torch from config import Config from transformers import BertTokenizer from student_bilstm_model import BILSTMStudentModel config Config() model BILSTMStudentModel().to(config.device) model.load_state_dict(torch.load(config.student_model_path)) model.eval() tokenizer BertTokenizer.from_pretrained(config.bert_path) def model_predict(json_data): # 1 - 外部数据 取得句子 title json_data[title] # 2 - 文本转张量 获得 input_ids, attention_mask title_tensor tokenizer( [title], paddingmax_length, truncationTrue, max_lengthconfig.max_seq_len, return_tensorspt ) input_ids title_tensor.input_ids.to(config.device) attention_mask title_tensor.attention_mask.to(config.device) with torch.no_grad(): output model(input_ids, attention_mask) output_index torch.argmax(output, dim-1).item() #取概率最大的索引值 pred_class_name config.classname_list[output_index] json_data[pred_class] pred_class_name return json_data if __name__ __main__: print(model_predict({title: 体验2D巅峰 倚天屠龙记十大创新新概览}))