LLM 训练能不能少跑一点?Nous Research 的 TST 方法 大模型预训练的开销非常高这已经不是新鲜事。随着模型规模的不断扩大训练数据需求会持续增加训练周期和算力成本也越来越难以忽视。因此过去一两年LLM 研发团队一直在尝试提升 LLM 预训练的效率。相关方法大致可以分为三类从模型结构入手比如 MoE 和稀疏注意力改进输入表示比如 tokenizer 和 SuperBPE调整训练目标比如多 token 预测 multi-token prediction让模型一次预测未来多个 token。这些方向都很重要但它们也带来了一个共同的问题当模型结构、输入方式、训练目标甚至推理形态都发生变化时我们很难单独判断性能提升究竟来自哪里。Nous Research 最近这篇论文「Efficient Pre-Training with Token Superposition」选择了一个更克制的切入点能不能只在训练阶段动手但不改变最终模型论文概述在论文「Efficient Pre-Training with Token Superposition」中Nous Research 提出的方法叫 Token-Superposition Training简称 TST。论文摘要里给出的定位很明确TST 可以直接接入现有预训练流程不需要改变并行策略、优化器、分词器、训练数据或模型架构却能提高预训练阶段单位计算量下的数据吞吐效率。上图展示了在 10B-A1B MoE 模型上TST 和对照组 baseline 在训练损失 loss 曲线上的对比。论文中这一实验显示在达到相同训练损失的前提下TST 最多可以让总预训练时间减少约 2.5 倍。TST 的核心思路其实不复杂。标准 LLM 预训练是模型读入一串 token然后预测下一个 token。而 TST 在训练前期做了一件事把连续多个 token 合成一个 bag。比如原来模型要依次处理t1, t2, t3, t4, t5, t6, t7, t8如果 bag size 是 4它就会变成[t1, t2, t3, t4], [t5, t6, t7, t8]在输入侧TST 会把一个 bag 里多个 token 的向量表示token embedding做平均形成一个更粗粒度的叠加 token。在输出侧模型也不再只预测下一个 token而是预测下一个 bag 里可能会出现哪些 token。这样一来Transformer 实际处理的序列长度变短了但背后对应的原始文本 token 更多。论文把这个阶段叫做 token 叠加阶段 superposition phase。不过这种训练方式不能一直用下去。因为 bag 里 token 的顺序信息被抹掉了模型在这个阶段学到的是一种更粗粒度的未来 token 分布。如果只用这个目标训练模型不能直接变成一个正常可用的自回归语言模型。论文方法部分也提到只用 TST 训练的模型会产生混合的未来 token 概率分布推理输出会变得不正常。所以TST 将训练分为两阶段前期通过 token 叠加提高训练吞吐后期再切回标准的下一个 token 预测让模型恢复成普通的自回归语言模型。它真正有意思的地方也在这里训练过程可以临时变化但最终模型仍然是普通 LLM。下面我们就论文的部分内容展开详细的讲解预训练效率论文的开篇作者将现有的预训练效率方法大致分成三类第一类是information maximization也就是提高每个样本携带的信息密度。比如更好的 tokenization、SuperBPE、n-gram hashing以及多 token 预测等更丰富的训练信号。第二类是compute sparsity也就是减少每个 token 需要的计算量。典型例子是 MoE 和稀疏注意力。第三类是compressive modeling先在模型内部压缩 token 表示减少需要经过昂贵模型层的向量数量从而降低训练成本。但上面这些方法会影响训练时效率和推理时行为。TST 试图把问题进行拆解只关注预训练阶段的效率同时尽量不改变最终模型架构和推理方式。这也是 TST 和很多效率方法不太一样的地方。它不是为了让模型推理更快也不是为了让最终模型变成一种新的结构。它只是在训练前期改变 token 的组织方式让模型先以更粗粒度读文本等训练进入后期再回到标准语言建模。TST 的特别之处TST 不是凭空冒出来的它和现在已有的几个优化方向都有关系但又不完全一样。首先它和 tokenizer、SuperBPE 这类方法有关系。这些方法本身会改变模型看到文本的粒度。更粗的 token 粒度往往意味着同样长度的序列可以承载更多原始文本。论文中提到相比 byte-level 这种更细粒度的文本表示subword 分词能用更短的序列承载同样的文本内容因此样本吞吐效率会更高。但 TST 不改最终 tokenizer。它只是在训练阶段临时把多个 token 合成一个 bag后面仍然回到原来的 token 粒度。其次它和多 token 预测 multi-token prediction简称MTP也有相似处。MTP 的思路是不要只预测下一个 token而是同时预测未来多个 token。这样每一步训练可以获得更多监督信号。TST 的输出侧也在利用未来多个 token 的信息但它不只是“多预测几个 token”。它还会在输入侧把多个 token 的向量表示合并成一个更粗粒度的“叠加 token”让模型在训练时以更大的粒度处理文本。论文在 Discussion 部分专门解释了这一点MTP 及其变体通常不会提高训练阶段的吞吐效率模型每消耗同样的计算量处理的 token 数和对照组基本一样只是额外增加了辅助预测头和损失函数。TST 的不同之处在于它的目标就是提高训练阶段单位计算量能处理的 token 数同时不改变推理时的模型架构。所以更准确地说TST 不是在和 MTP 抢同一个位置。作者把它看作和 auxiliary-loss methods 正交未来甚至可以组合。TST 怎么做论文方法部分提到和标准的下一个 token 预测相比TST 主要有两个改动。第一个改动发生在输入侧。TST 把连续的 token 切成不重叠的 s-grams也就是长度为 s 的 token bag。然后它会对这个 bag 里所有 token 的向量表示取平均得到一个更粗粒度的表示论文里称为 s-token。如果原始序列长度是 Lbag size 是 s那么模型实际处理的序列长度就会变成 L / s。这意味着进入 Transformer 层计算的序列变短了模型每一步需要处理的表示数量也变少了。但因为每个 s-token 背后对应多个原始 token所以在训练时单位计算可以覆盖更多文本。第二个改动发生在输出侧。标准语言模型在每个位置预测下一个 tokenTST 则让每个 s-token 去预测下一个 bag-of-tokens。也就是说它预测的不是“下一个 token 是什么”而是“下一组 token 里会有哪些 token”。为了实现这个目标论文使用的是 multi-hot 交叉熵。这里的 multi-hot 可以理解为一个位置不只对应一个正确答案而是同时对应多个正确 token。普通交叉熵通常是一个位置对应一个正确 token而 multi-hot 交叉熵允许一个位置对应多个正确 token。这样模型可以在同一个位置同时接收来自多个未来 token 的训练信号。TST 方法示意图上图对比了几种不同的训练方式标准的下一个 token 预测 next token prediction、多 token 预测、SuperBPE 和 Token Superposition。TST 的特点是输入侧先把多个连续 token 的向量表示平均成一个 s-token输出侧再让模型预测下一组 token 中会出现哪些 token。这里还有两个关键超参数一个是bag size s也就是每几个 token 合成一个 bag。另一个是step ratio r也就是总训练步数里有多少比例使用 TST。比如 s8、r0.3意思就是训练前 30% 的 step 使用 TST每 8 个 token 合成一个 bag后面 70% 的 step 切回标准下一个 token 预测。这也是 TST 的核心结构训练前期粗粒度、高吞吐 训练后期标准粒度、恢复自回归建模标准训练的必要性为什么后面必须切回标准训练这是理解 TST 很关键的一点。TST 阶段会把多个 token 合成一个 bag因此模型学到的是“下一组 token 里会出现哪些 token”而不是严格意义上的“下一个 token 是什么”。这样做可以提高训练吞吐但也会带来一个问题bag 内部的 token 顺序信息会被弱化模型不能直接像普通语言模型那样逐 token 生成文本。论文方法部分把这个阶段称为 “半因果、半自回归”semi-causal, semi-autoregressive。也就是说模型整体上仍然沿着从左到右的方向学习后续内容但因为预测目标变成了 token bag它还不是一个标准的自回归语言模型。所以TST 必须有一个后续的恢复阶段 recovery phase先用 TST 做一段粗粒度训练再切回标准的下一个 token 预测让模型恢复正常的语言生成能力。这也说明TST 不是要抛弃下一个 token 预测而是把预训练拆成两段前期用粗粒度目标提高吞吐后期用标准目标恢复生成能力。实验结果论文的实验部分主要在验证两件事第一TST 在不同规模模型上是否稳定第二TST 的收益应该放在什么比较口径下理解——它是相同算力下更有效还是相同数据下也更有效。论文在 270M、600M 稠密模型上做了较多实验用来探索不同 bag size 和 step ratio 的效果然后又在 3B 稠密模型和 10B-A1B MoE 模型上做了验证。所有训练都使用 TorchTitan较大模型在 64 张 NVIDIA B200 GPU 上运行小模型在 8 张 B200 GPU 上运行。小模型数据从小模型实验看TST 对超参数不是完全不敏感但在一定范围内比较稳。论文结论部分给出的经验区间是bag size 大致在 4 到 8step ratio 大致在 0.2 到 0.4。上面一张图主要是展示不同 superposition bag size 和 step ratio 下的训练损失表现下面一张图主要展示对应的下游评测平均结果。两张图一起说明了 TST 不是随便把 bag size 调大就行参数需要落在合适范围内。更值得关注的是大模型实验。大模型数据在 10B-A1B MoE 模型上对照组训练了 1.05 万亿个 token而 TST 组训练了 2 万亿个数据 token。本文图 1 说明了两组实验中每个训练 step 的计算量相同因此可以直接根据训练步数计算加速比。论文摘要部分也提到在达到相同训练损失的前提下10B-A1B 规模上TST 最多可以实现约 2.5 倍训练加速也就是把总预训练时间缩短约 60%。表1 给出了更具体的结果数据在 10B-A1B 模型上baseline 对照组消耗了 12,311 个 B200 GPU 小时最终训练损失为 2.252TST 组消耗了 4,768 个 B200 GPU 小时最终训练损失为 2.236。下游评测中TST 在 HellaSwag、ARC-E、ARC-C 和 MMLU 上的得分也都高于对照组。不过这个结果不能简单地理解为“训练成本直接免费降了 2.5 倍”。因为 TST 的关键前提是它会更快消耗数据。TST 不是银弹TST 的核心取舍可以概括成一句话用更高的数据吞吐换更短的训练时间。它不是在相同数据消耗下白捡收益。相反在 TST 阶段模型每一步背后对应更多原始 token所以它会更快吃掉训练语料。这意味着TST 是否划算取决于你真正稀缺的是什么。如果训练主要卡在算力上而高质量数据还比较充足那么 TST 可能有用武之地。它可以让模型在相同计算预算下处理更多原始 token或者用更少训练时间达到相同的训练损失。但如果真正稀缺的是高质量数据TST 的优势就会减弱。论文在局限性部分也明确提到TST 本质上是在固定计算成本下用更多数据消耗换取更低的训练损失。也就是说它成立的前提是 LLM 预训练更偏向算力受限而不是数据受限。图注equal-FLOPs / equal-loss / equal-data 对比上图展示了 baseline 和 TST 在三种比较口径下的差异相同计算量、相同训练损失、相同数据量。TST 的优势主要体现在 equal-FLOPs 或 equal-loss 场景如果比较相同 token 消耗它不一定占优。关键分析除了主实验Discussion 里还有几组分析值得单独看。第一组分析是TST 不只是一个单点技巧。作者做了 input-only、output-only 和 full superposition 三组消融实验。结果显示只做输入侧叠加或者只做输出侧叠加都能超过对照组但输入侧和输出侧同时使用也就是完整的 TST效果最好。这说明 TST 至少包含两个相对独立的机制输入侧通过改变输入粒度降低单位信息经过模型计算的成本输出侧则通过改变预测目标让模型获得来自未来多个 token 的训练信号。这一点很重要。因为只看表面TST 很容易被理解成“把几个 token 的向量表示平均一下”。但论文的消融实验说明它并不只是输入压缩。输出侧的 next bag-of-tokens prediction也在为模型提供额外的训练信号。图注Input / Output Superposition 消融第二组分析是两个训练阶段之间的表示连续性很重要。TST 能够从前期的 token 叠加阶段切回后期的恢复阶段一个关键原因是两个阶段共享同一套输入向量表示和输出层也就是论文中提到的 input embedding 和 output LM head。作者做了一个验证实验在恢复阶段开始前随机重新初始化 input embedding 和 output LM head。结果显示TST 的收益会消失表现甚至比对照组更差。表 2 中3B 对照组的最终训练损失是 2.808正常 TST 是 2.676而重置 input embedding 和 output LM head 之后TST 的最终训练损失变成了 2.938。这说明TST 能成功不只是因为模型在前期处理了更多原始 token。更关键的是前期 token 叠加训练学到的表示能够延续到后期的标准下一个 token 预测中。如果两个阶段之间的表示无法衔接前期训练带来的收益就很难保留下来。局限性TST 的限制也比较明确。第一它依赖数据充足这一前提。如果未来高质量文本数据变得更稀缺而不是算力更稀缺那么 TST 这种更快消耗数据的方法可能会遇到问题。论文也提到在这种情况下output-only superposition 可能更值得研究因为它不增加数据消耗但作者把这部分留给未来工作。第二它可能对长上下文有潜在影响但论文没有评估。由于 TST 会把原始序列折叠成 token bag相当于在 TST 阶段拥有更长的有效上下文。作者推测这种机制可能有助于提升模型的长上下文能力但本文没有进一步做相关实验。第三实验规模还不能直接外推到所有大模型训练。论文已经验证到 3B 稠密模型和 10B-A1B MoE但作者也承认受算力限制他们没有做更大规模的消融实验也没有做多次相同实验来评估统计显著性。未来还需要研究 token superposition 的 scaling law才能判断更大模型上的最佳设置。第四作者提出了一些可能的解释。比如输入侧的 token 叠加可能类似一种 “由粗到细”coarse-to-fine 的训练过程模型先接触更粗粒度、更简单的文本统计结构再切回完整分辨率的语言建模。但论文也明确提到目前还缺少足够的可解释性证据无法判断 TST 具体是通过哪些机制带来收益的。因此它现阶段更像是一个有效的经验方法而不是一个机制已经被充分解释清楚的理论结论。小结我觉得这篇论文最值得关注的地方不只是“训练快了多少”而是它把预训练效率问题拆得更干净了训练阶段可以临时改变任务但最终模型不一定要变。过去我们谈训练效率常常会想到改模型结构、改分词器、改注意力机制或改推理方式。TST 提供了另一种思路训练前期先用更粗粒度的 token 叠加让模型以更高吞吐处理文本训练后期再切回标准的下一个 token 预测把模型恢复成正常的自回归语言模型。当然TST 还不是一个可以直接套用到所有训练任务上的通用答案。它依赖数据是否充足也需要更大规模实验和 scaling law 支撑。但它至少提供了一个值得继续验证的方向预训练不一定从头到尾都要使用同一种粒度、同一个训练目标。