# 别再死记硬背CRF公式了！用TensorFlow 2.x手写一个命名实体识别(NER)层，从代码反推原理

## 从零实现TensorFlow 2.x CRF层：代码反推NER核心原理

在自然语言处理领域，命名实体识别(NER)任务常采用条件随机场(CRF)作为解码层。但大多数教程停留在数学公式推导，让开发者陷入“看得懂、推不通；推得通、写不出”的困境。本文将以代码实现为主导，通过手写TensorFlow 2.x的CRF层，逆向解析其核心原理。我们将从定义发射分数和转移矩阵开始，逐步实现前向计算、损失函数和维特比解码，最终整合成可复用的CRF层模块。

## 1. CRF核心概念与实现准备

CRF作为判别式概率模型，其核心在于考虑相邻标签间的转移特性。例如在BIO标注体系中，I-PER只能跟在B-PER或I-PER之后：B-PER后面可以接I-PER、O或新的B-xxx，但不能直接接I-ORG；O后面也不能直接接I-PER。CRF能够从数据中学到这类标签依赖关系，这使它成为序列标注任务的常用选择。

实现CRF层需要三个关键组件：

- **发射分数(Emission Scores)**：由上层模型输出的、每个位置上每个标签的未归一化分数；
- **转移矩阵(Transition Matrix)**：存储标签间转移分数（未归一化，并非概率）的可训练参数矩阵；
- **维特比算法(Viterbi)**：计算最优标签序列的动态规划算法。

先导入必要的库，并定义CRF层的参数：

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class CRF(Layer):
    def __init__(self, num_tags, **kwargs):
        super().__init__(**kwargs)
        self.num_tags = num_tags  # number of tags K
        # transitions[i, j]: score of moving from tag i to tag j
        self.transitions = self.add_weight(
            name="transitions", shape=(num_tags, num_tags),
            initializer="glorot_uniform", trainable=True)
        # Scores of a sequence starting with / ending on each tag
        self.start_transitions = self.add_weight(
            name="start_transitions", shape=(num_tags,),
            initializer="zeros", trainable=True)
        self.end_transitions = self.add_weight(
            name="end_transitions", shape=(num_tags,),
            initializer="zeros", trainable=True)
```

这里`num_tags`表示标签数量，`transitions`是随机初始化的可训练转移矩阵。例如BIO标注体系有3个标签（B、I、O），则`num_tags=3`。序列“以哪个标签开头”“以哪个标签结尾”的分数分别由`start_transitions`和`end_transitions`两个向量单独建模，而不是借用某个真实标签充当开始/结束状态，以免与真实标签之间的转移分数混在一起。

> 提示：转移矩阵的维度是`[标签数量, 标签数量]`，每个元素`transitions[i][j]`表示从标签i转移到标签j的分数。

## 2. 实现CRF的前向计算

前向计算需要完成两项工作：计算真实路径的分数，以及计算所有可能路径的总分数（配分函数）。两者相减得到真实路径的对数似然，取负后就是训练时的损失。

定义前向计算函数：

```python
    def call(self, inputs, targets=None, sequence_lengths=None, training=None):
        # Emission scores from the upstream model: [batch_size, seq_len, num_tags]
        emissions = tf.cast(inputs, self.transitions.dtype)
        if sequence_lengths is None:
            # No lengths given: every position of every sequence is valid
            sequence_lengths = tf.fill([tf.shape(emissions)[0]], tf.shape(emissions)[1])
        sequence_lengths = tf.cast(sequence_lengths, tf.int32)
        if training and targets is not None:
            # Training: register the batch-mean negative log-likelihood as the loss
            log_likelihood = self._compute_log_likelihood(
                emissions, tf.cast(targets, tf.int32), sequence_lengths)
            self.add_loss(-tf.reduce_mean(log_likelihood))
        # Always return the Viterbi-decoded tag ids: [batch_size, seq_len]
        return self._viterbi_decode(emissions, sequence_lengths)
```

其中`_compute_log_likelihood`函数计算每条序列的对数似然（取负后即为损失），`_viterbi_decode`函数实现维特比解码算法。`sequence_lengths`中的每个长度都应不小于1；不传时视为所有序列都是满长度。

### 2.1 计算真实路径分数

真实路径分数 = 开始分数 + 各有效位置的发射分数 + 相邻标签之间的转移分数 + 结束分数：

```python
    def _compute_log_likelihood(self, emissions, tags, sequence_lengths):
        seq_len = tf.shape(emissions)[1]
        # Mask for padded positions of variable-length sequences: [batch, seq_len]
        mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len, dtype=emissions.dtype)

        # Emission score of the gold tag at every position: [batch, seq_len]
        emit_scores = tf.gather(emissions, tags, axis=2, batch_dims=2)
        emit_scores = tf.reduce_sum(emit_scores * mask, axis=1)

        # Transition scores between consecutive gold tags: [batch, seq_len - 1]
        pair_indices = tf.stack([tags[:, :-1], tags[:, 1:]], axis=-1)
        transition_scores = tf.gather_nd(self.transitions, pair_indices)
        # The transition into position t only counts if position t is valid
        transition_scores = tf.reduce_sum(transition_scores * mask[:, 1:], axis=1)

        # Scores of the first tag and of the last valid tag of each sequence
        start_scores = tf.gather(self.start_transitions, tags[:, 0])
        last_tags = tf.gather(tags, sequence_lengths - 1, axis=1, batch_dims=1)
        end_scores = tf.gather(self.end_transitions, last_tags)

        # log p(y|x) = score(x, y) - log Z(x), one value per sequence: [batch]
        gold_scores = emit_scores + transition_scores + start_scores + end_scores
        log_partition = self._compute_log_partition_function(emissions, sequence_lengths)
        return gold_scores - log_partition
```

### 2.2 计算所有路径分数（配分函数）

配分函数Z的计算采用动态规划（前向算法），把直接枚举所有路径的O(K^T)复杂度降到O(T·K²)（T为序列长度，K为标签数）：

```python
    def _compute_log_partition_function(self, emissions, sequence_lengths):
        seq_len = tf.shape(emissions)[1]
        mask_t = tf.transpose(tf.sequence_mask(sequence_lengths, maxlen=seq_len))  # [seq_len, batch]
        emissions_t = tf.transpose(emissions, perm=[1, 0, 2])  # [seq_len, batch, num_tags]

        def step(t, log_alpha):
            # Broadcast: log_alpha [batch, num_tags, 1] + transitions [1, num_tags, num_tags]
            # + emissions [batch, 1, num_tags] -> [batch, from_tag, to_tag]
            scores = (log_alpha[:, :, None]
                      + self.transitions[None, :, :]
                      + emissions_t[t][:, None, :])
            # logsumexp over the previous tag (axis 1)
            next_log_alpha = tf.reduce_logsumexp(scores, axis=1)
            # Padded positions keep the previous forward variable unchanged
            log_alpha = tf.where(mask_t[t][:, None], next_log_alpha, log_alpha)
            return t + 1, log_alpha

        # Initial forward variable: start score + emission at position 0
        init_log_alpha = self.start_transitions[None, :] + emissions_t[0]
        _, log_alpha = tf.while_loop(
            lambda t, _: t < seq_len, step, (tf.constant(1), init_log_alpha))

        # Add the end scores and sum over the final tag: log Z per sequence, [batch]
        return tf.reduce_logsumexp(log_alpha + self.end_transitions[None, :], axis=1)
```

这里用`tf.while_loop`显式写循环，在Eager模式和`tf.function`图模式下都能正常运行。掩码用`tf.where`按位置选择新旧值，比`x * mask + y * (1 - mask)`更直观，也避免了`0 * inf`带来的NaN。

## 3. 维特比解码算法实现

预测阶段需要找到分数最高的标签序列，这可以通过维特比算法实现：

```python
    def _viterbi_decode(self, emissions, sequence_lengths):
        batch_size = tf.shape(emissions)[0]
        seq_len = tf.shape(emissions)[1]
        mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len)  # [batch, seq_len]
        mask_t = tf.transpose(mask)  # [seq_len, batch]
        emissions_t = tf.transpose(emissions, perm=[1, 0, 2])  # [seq_len, batch, num_tags]
        # "Stay on the same tag" pointers, used at padded positions so that
        # back-tracking passes through the padding unchanged
        identity = tf.tile(tf.range(self.num_tags)[None, :], [batch_size, 1])

        def forward_step(t, score, backpointers):
            # [batch, from_tag, to_tag]
            scores = score[:, :, None] + self.transitions[None, :, :]
            # Best previous tag for every current tag, and the corresponding score
            best_prev = tf.argmax(scores, axis=1, output_type=tf.int32)
            next_score = tf.reduce_max(scores, axis=1) + emissions_t[t]
            valid = mask_t[t][:, None]
            score = tf.where(valid, next_score, score)
            backpointers = backpointers.write(t, tf.where(valid, best_prev, identity))
            return t + 1, score, backpointers

        init_score = self.start_transitions[None, :] + emissions_t[0]
        backpointers = tf.TensorArray(tf.int32, size=seq_len).write(0, identity)
        _, score, backpointers = tf.while_loop(
            lambda t, *_: t < seq_len, forward_step,
            (tf.constant(1), init_score, backpointers))
        backpointers = backpointers.stack()  # [seq_len, batch, num_tags]

        # Best final tag (end scores included), then follow the pointers backwards
        last_tags = tf.argmax(score + self.end_transitions[None, :], axis=1,
                              output_type=tf.int32)

        def backward_step(t, current, best_tags):
            # backpointers[t] maps the tag at position t to the best tag at t - 1
            prev = tf.gather(backpointers[t], current, batch_dims=1)
            return t - 1, prev, best_tags.write(t - 1, prev)

        best_tags = tf.TensorArray(tf.int32, size=seq_len).write(seq_len - 1, last_tags)
        _, _, best_tags = tf.while_loop(
            lambda t, *_: t > 0, backward_step, (seq_len - 1, last_tags, best_tags))
        best_tags = tf.transpose(best_tags.stack())  # [batch, seq_len]
        # Padded positions are filled with 0
        return best_tags * tf.cast(mask, tf.int32)
```

与前向算法相比，维特比算法只是把logsumexp换成了max，并额外记录每一步取得最大值的“上一个标签”（反向指针）；最后从分数最高的结束标签出发，沿反向指针回溯即可得到最优路径。填充位置的反向指针指向自身，因此回溯时会原样穿过填充部分。

## 4. 整合CRF层与模型训练

将实现的CRF层整合到模型中，以BIO标注任务为例：

```python
class NERModel(tf.keras.Model):
    def __init__(self, num_tags):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(10000, 128)
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(64, return_sequences=True)
        )
        self.dense = tf.keras.layers.Dense(num_tags)
        self.crf = CRF(num_tags)

    def call(self, inputs, targets=None, sequence_lengths=None, training=None):
        x = self.embedding(inputs)
        x = self.bilstm(x, training=training)
        logits = self.dense(x)
        # Pass targets/lengths/training explicitly: the CRF layer only adds its
        # loss when training is true and targets are given
        return self.crf(logits, targets=targets,
                        sequence_lengths=sequence_lengths, training=training)
```

训练时直接使用CRF层计算出的负对数似然作为损失函数：

```python
model = NERModel(num_tags=3)
optimizer = tf.keras.optimizers.Adam(0.001)


@tf.function
def train_step(x, y, lengths):
    with tf.GradientTape() as tape:
        # The forward pass registers the CRF loss via add_loss
        model(x, targets=y, sequence_lengths=lengths, training=True)
        loss = tf.add_n(model.losses)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```

预测时使用维特比解码得到最优标签序列：

```python
def predict(model, x, lengths=None):
    # Decoded tag ids: [batch_size, seq_len]
    return model(x, sequence_lengths=lengths, training=False)
```

## 5. CRF实现中的关键细节

### 5.1 处理变长序列

NER任务中文本长度不一，需要正确处理变长序列。我们通过`sequence_lengths`参数和掩码机制实现：

```python
# Build the mask for variable-length sequences
sequence_lengths = tf.constant([10, 7, 12], dtype=tf.int32)  # actual lengths of three samples
max_len = tf.reduce_max(sequence_lengths)
mask = tf.sequence_mask(sequence_lengths, maxlen=max_len, dtype=tf.float32)  # [3, 12]

# Apply the mask: per-position scores at padded positions become 0
per_step_scores = tf.random.normal([3, 12])
masked_scores = per_step_scores * mask
```

计算真实路径分数时，可以直接把填充位置的分数乘以0；而在前向算法和维特比解码中，填充位置应保持上一步的状态不变（即`tf.where`的用法），这样每条序列的最终状态就等于它最后一个有效位置的状态。

### 5.2 数值稳定性

CRF计算涉及大量指数运算，容易导致数值不稳定。我们采用log空间计算提升稳定性：

```python
some_scores = tf.constant([1000.0, 1001.0, 1002.0])

# Probability space is unstable: exp overflows to inf, and inf / inf gives nan
exp_scores = tf.exp(some_scores)
sum_exp = tf.reduce_sum(exp_scores)
prob = exp_scores / sum_exp

# Log space is stable
log_scores = some_scores
log_sum_exp = tf.reduce_logsumexp(log_scores)
log_prob = log_scores - log_sum_exp
```

### 5.3 转移矩阵约束

某些标签转移在业务中不可能发生（如O→I、B-PER→I-ORG这类非法转移），可通过约束转移矩阵实现。注意不能把非法转移的分数乘以0：分数0只是一个普通的分数，并不代表“不可能”。正确做法是给非法转移加上一个很大的负数：

```python
# Tag order: 0 = B, 1 = I, 2 = O (single entity type)
# allowed[i][j] = 1 if the transition i -> j is legal
allowed = tf.constant([[1., 1., 1.],   # B -> B, I, O are all legal
                       [1., 1., 1.],   # I -> B, I, O are all legal
                       [1., 0., 1.]])  # O -> I is illegal
crf = model.crf
# Apply the constraint before training: illegal transitions get a large negative score
crf.transitions.assign(crf.transitions * allowed + (1.0 - allowed) * -1e4)
# A sequence cannot start with I either
crf.start_transitions.assign([0.0, -1e4, 0.0])
```

非法转移从不出现在真实路径中，且在模型中的概率接近0，因此训练时它们的梯度也接近0，分数基本保持为很大的负数。若希望约束严格成立，也可以在每次计算时把`(1 - allowed) * -1e4`作为常量加到转移矩阵上。

## 6. 完整代码示例

以下是完整可运行的TensorFlow 2.x CRF层实现（额外实现了`get_config`，便于保存和加载模型）：

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class CRF(Layer):
    """Linear-chain conditional random field layer.

    With ``training=True`` and ``targets`` given, it adds the batch-mean
    negative log-likelihood via ``add_loss``. It always returns the
    Viterbi-decoded tag ids.

    Args:
        num_tags: number of tags K.

    Weights:
        transitions: [K, K]; ``transitions[i, j]`` scores moving from tag i to tag j.
        start_transitions: [K]; score of a sequence starting with each tag.
        end_transitions: [K]; score of a sequence ending on each tag.
    """

    def __init__(self, num_tags, **kwargs):
        super().__init__(**kwargs)
        self.num_tags = num_tags
        self.transitions = self.add_weight(
            name="transitions", shape=(num_tags, num_tags),
            initializer="glorot_uniform", trainable=True)
        self.start_transitions = self.add_weight(
            name="start_transitions", shape=(num_tags,),
            initializer="zeros", trainable=True)
        self.end_transitions = self.add_weight(
            name="end_transitions", shape=(num_tags,),
            initializer="zeros", trainable=True)

    def call(self, inputs, targets=None, sequence_lengths=None, training=None):
        """Decode the best tag sequence; in training also add the CRF loss.

        Args:
            inputs: emission scores, [batch_size, seq_len, num_tags].
            targets: gold tag ids, [batch_size, seq_len]; used only in training.
            sequence_lengths: valid lengths, [batch_size], each >= 1. Defaults
                to seq_len for every sequence.
            training: whether the layer runs in training mode.

        Returns:
            int32 tag ids, [batch_size, seq_len]; padded positions are 0.
        """
        emissions = tf.cast(inputs, self.transitions.dtype)
        if sequence_lengths is None:
            # No lengths given: every position of every sequence is valid
            sequence_lengths = tf.fill([tf.shape(emissions)[0]], tf.shape(emissions)[1])
        sequence_lengths = tf.cast(sequence_lengths, tf.int32)
        if training and targets is not None:
            log_likelihood = self._compute_log_likelihood(
                emissions, tf.cast(targets, tf.int32), sequence_lengths)
            self.add_loss(-tf.reduce_mean(log_likelihood))
        return self._viterbi_decode(emissions, sequence_lengths)

    def _compute_log_likelihood(self, emissions, tags, sequence_lengths):
        """Return log p(tags | emissions) for every sequence, shape [batch]."""
        seq_len = tf.shape(emissions)[1]
        mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len, dtype=emissions.dtype)

        # Emission score of the gold tag at every valid position
        emit_scores = tf.gather(emissions, tags, axis=2, batch_dims=2)
        emit_scores = tf.reduce_sum(emit_scores * mask, axis=1)

        # Transition scores between consecutive gold tags; the transition into
        # position t only counts if position t is valid
        pair_indices = tf.stack([tags[:, :-1], tags[:, 1:]], axis=-1)
        transition_scores = tf.gather_nd(self.transitions, pair_indices)
        transition_scores = tf.reduce_sum(transition_scores * mask[:, 1:], axis=1)

        # Scores of the first tag and of the last valid tag
        start_scores = tf.gather(self.start_transitions, tags[:, 0])
        last_tags = tf.gather(tags, sequence_lengths - 1, axis=1, batch_dims=1)
        end_scores = tf.gather(self.end_transitions, last_tags)

        gold_scores = emit_scores + transition_scores + start_scores + end_scores
        log_partition = self._compute_log_partition_function(emissions, sequence_lengths)
        return gold_scores - log_partition

    def _compute_log_partition_function(self, emissions, sequence_lengths):
        """Return log Z (log-sum of all path scores) per sequence via the forward algorithm."""
        seq_len = tf.shape(emissions)[1]
        mask_t = tf.transpose(tf.sequence_mask(sequence_lengths, maxlen=seq_len))  # [seq_len, batch]
        emissions_t = tf.transpose(emissions, perm=[1, 0, 2])  # [seq_len, batch, num_tags]

        def step(t, log_alpha):
            # [batch, from_tag, to_tag]; reduce over the previous tag
            scores = (log_alpha[:, :, None]
                      + self.transitions[None, :, :]
                      + emissions_t[t][:, None, :])
            next_log_alpha = tf.reduce_logsumexp(scores, axis=1)
            # Padded positions keep the previous forward variable
            log_alpha = tf.where(mask_t[t][:, None], next_log_alpha, log_alpha)
            return t + 1, log_alpha

        init_log_alpha = self.start_transitions[None, :] + emissions_t[0]
        _, log_alpha = tf.while_loop(
            lambda t, _: t < seq_len, step, (tf.constant(1), init_log_alpha))
        return tf.reduce_logsumexp(log_alpha + self.end_transitions[None, :], axis=1)

    def _viterbi_decode(self, emissions, sequence_lengths):
        """Return the highest-scoring tag ids, [batch, seq_len]; padded positions are 0."""
        batch_size = tf.shape(emissions)[0]
        seq_len = tf.shape(emissions)[1]
        mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len)
        mask_t = tf.transpose(mask)
        emissions_t = tf.transpose(emissions, perm=[1, 0, 2])
        # Identity pointers for padded positions, so back-tracking passes through them
        identity = tf.tile(tf.range(self.num_tags)[None, :], [batch_size, 1])

        def forward_step(t, score, backpointers):
            scores = score[:, :, None] + self.transitions[None, :, :]
            best_prev = tf.argmax(scores, axis=1, output_type=tf.int32)
            next_score = tf.reduce_max(scores, axis=1) + emissions_t[t]
            valid = mask_t[t][:, None]
            score = tf.where(valid, next_score, score)
            backpointers = backpointers.write(t, tf.where(valid, best_prev, identity))
            return t + 1, score, backpointers

        init_score = self.start_transitions[None, :] + emissions_t[0]
        backpointers = tf.TensorArray(tf.int32, size=seq_len).write(0, identity)
        _, score, backpointers = tf.while_loop(
            lambda t, *_: t < seq_len, forward_step,
            (tf.constant(1), init_score, backpointers))
        backpointers = backpointers.stack()  # [seq_len, batch, num_tags]

        last_tags = tf.argmax(score + self.end_transitions[None, :], axis=1,
                              output_type=tf.int32)

        def backward_step(t, current, best_tags):
            # backpointers[t] maps the tag at position t to the best tag at t - 1
            prev = tf.gather(backpointers[t], current, batch_dims=1)
            return t - 1, prev, best_tags.write(t - 1, prev)

        best_tags = tf.TensorArray(tf.int32, size=seq_len).write(seq_len - 1, last_tags)
        _, _, best_tags = tf.while_loop(
            lambda t, *_: t > 0, backward_step, (seq_len - 1, last_tags, best_tags))
        best_tags = tf.transpose(best_tags.stack())
        return best_tags * tf.cast(mask, tf.int32)

    def get_config(self):
        config = super().get_config()
        config.update({"num_tags": self.num_tags})
        return config
```