JAX 深度学习框架核心机制深度解析:从函数变换到自动并行化的编译优化原理前言核心痛点:本文解决业界对 JAX 框架底层机制的深度理解需求——多数 AI 工程师熟悉 PyTorch 的即时执行模式,但对 JAX 的函数式变换哲学、JIT 编译流水线、自动并行化机制缺乏系统性认知,导致在选型时无法客观评估两套技术栈的优劣,或在迁移到 JAX 生态时遭遇"思维范式墙"。适配人群:具备 PyTorch/TensorFlow 使用经验的中高级 AI 工程师、深度学习框架开发者、对编译器优化感兴趣的系统工程师、正在评估 JAX 技术栈的架构师。收获能力:读完可掌握 JAX 函数变换体系(jit/grad/vmap/pmap)的底层原理 + XLA 编译优化全链路 + SPMD 自动并行化机制 + 生产级分布式训练落地实战能力。目录1. 技术背景与演进逻辑2. 核心原理深度解析3. 函数变换体系:JAX 的四大基石4. XLA 编译流水线与 Jaxpr 中间表示5. 分布式并行化架构6. JAX 生态体系全景7. JAX vs PyTorch 技术对比8. 技术优缺点与适用场景9. 实战落地10. 全文总结11. 系列说明12. 参考资料1. 技术背景与演进逻辑1.1 JAX 的诞生背景2018 年,Google Brain 团队发布了一篇名为《JAX: composable transformations of Python+NumPy programs》的技术报告,正式向社区推出 JAX 框架。彼时,深度学习框架的竞争格局已经明朗:TensorFlow 凭借静态图 + 工业级部署能力占据生产环境主导地位,PyTorch 以动态图 + Pythonic 编程体验迅速赢得研究社区的青睐。然而,这两个主流框架在设计哲学上都存在各自的妥协。TensorFlow 1.x 的静态图虽然能进行全图优化,但session.run()的编程模型割裂了 Python 控制流与计算图构建,调试体验极为痛苦。PyTorch 的即时执行(eager execution)虽然调试友好,但运算逐条下发到设备执行,缺少跨操作的全局优化空间——即便后来的torch.compile通过TorchDynamo捕获子图进行部分编译,其优化深度仍受限于 Python 解释器的"图断裂(graph break)"问题。JAX 的创始团队看到了第三条路:将 NumPy 的易用性、函数式编程的可组合性、编译器优化的极致性能三者融合。他们选择的核心理念是:不是构建一个新的深度学习框架,而是构建一个通用的数值计算编译器,深度学习只是它的一个应用场景。这一理念体现在 JAX 的设计取舍中:设计维度PyTorchTensorFlow 2.xJAX执行模型即时执行 + 选择性编译即时执行 +tf.function默认即时执行 +jit编译自动微分动态计算图(tape-based)动态计算图(tape-based)函数变换(源码级变换)中间表示TorchDynamo → FX Graph → InductorGrappler → MLIR → XLATracing → Jaxpr → StableHLO → HLO并行模型DDP / FSDP(手动配置)tf.distribute(策略模式)jit+ sharding(编译器自动决策)数组语义可变(mutable)可变(mutable)不可变(immutable)随机数全局状态全局状态显式 Key(无状态)函数变换不支持不支持一等公民(jit/grad/vmap/pmap 任意组合)JAX 目前的最新稳定版本是v0.6.0(2026 年 6 月),底层编译器已从 XLA 迁移至OpenXLA社区开源项目,实现了与 TensorFlow、PyTorch(通过torch_xla)共享编译器基础设施。1.2 传统框架的核心局限要理解 JAX 为什么以"函数变换"作为核心范式,需要先审视传统框架在以下场景中的局限:局限一:自动微分的扩展性瓶颈。PyTorch 的autograd基于动态计算图,每次前向传播都会构建一张新的计算图,反向传播完成后销毁。这个模型对于简单的前馈网络足够高效,但当需要计算高阶导数(如 Hessian 矩阵)、梯度的梯度(meta-learning)、或需要对同一函数多次求导(如物理信息神经网络 PINN)时,动态图的"一次性"特质导致代码复杂度和内存开销急剧膨胀。局限二:手动批处理的工程负担。研究者从单样本调试转向批量训练时,需要手动重写代码——加 batch 维度、调整矩阵乘法维度、处理 broadcasting 语义。torch.vmap虽然已加入 PyTorch,但其实验性质和使用限制(不支持所有算子)使得自动向量化仍未成为 PyTorch 的核心工作流。局限三:分布式训练的配置熵。PyTorch 的分布式训练需要研究者显式管理设备拓扑、通信策略、参数分片方式。从DataParallel到DistributedDataParallel再到FullyShardedDataParallel,API 层层嵌套。到了 Tensor Parallelism + Pipeline Parallelism + Data Parallelism 的 3D 并行阶段,配置复杂度呈指数增长。局限四:Python 的执行开销。在即时执行模式下,每个操作都涉及 Python → C++ → CUDA 的多层调用,对于细粒度操作(如自定义激活函数中的逐元素计算),Python 解释器开销可能超过实际数值计算时间。JAX 的 JIT 编译直接消除这一开销。1.3 JAX 的技术演进路线JAX 的发展可以分为四个阶段:[JAX 2018 发布] ↓ [第一阶段: 函数变换核心] ├── jit / grad / vmap / pmap 四大变换趋于稳定 ├── jax.numpy 覆盖 NumPy API 的 90%+ └── TPU 支持使得 Google 内部大规模采用 ↓ [第二阶段: 分布式架构重构] ├── pjit 引入 SPMD 编程模型 ├── GDA (Global Device Array) / jax.Array 统一多设备数据抽象 ├── NamedSharding / PartitionSpec 声明式分片方案 └── shard_map 提供 SPMD 手动控制 ↓ [第三阶段: 生态整合] ├── OpenXLA 社区接管编译器维护 ├── JAX ↔ PyTorch 互操作(jax2torch / torch2jax) ├── Flax / Haiku / Equinox / NNX 等 NN 库百花齐放 └── Orbax 统一检查点格式 ↓ [第四阶段: 生产就绪(当前)] ├── Google DeepMind Gemini/AlphaFold 全系基于 JAX ├── JAX on GPU/TPU/CPU 三平台成熟 ├── 多主机多切片(Multislice)训练支持 └── Pallas 自定义 Kernel 语言2. 核心原理深度解析2.1 函数式编程核心理念JAX 的本质是一个函数变换系统。它的设计遵循以下数学直觉:设有一个纯函数f: X → Y,JAX 提供的每一个变换都是一个高阶函数(Higher-Order Function),即输入一个函数、输出一个新函数:jit(f): X → Y—— 将普通 Python 函数编译为 XLA 优化的可执行代码grad(f): X → ∇f(X)—— 生成原函数的梯度函数vmap(f): X^batch → Y^batch—— 将单样本函数自动提升为批量函数pmap(f): X^replicated → Y^sharded—— 将函数分布到多个设备并行执行这四个变换的高阶函数特性意味着:它们可以任意组合。# 四重变换叠加:编译 + 向量化 + 自动微分 + 多设备并行@jit@vmap@graddefloss_fn(params,x,y):returnjnp.sum((predict(params,x)-y)**2)# 等价于手写:对每个样本计算梯度,然后 JIT 编译并在多设备上并行这种"变换可组合性"是 JAX 区别于一切传统框架的根本特征。PyTorch 的torch.compile、autograd、vmap也可以组合使用,但它们之间缺乏统一的形式化接口——torch.compile是一个 FX 图变换,autograd是一个上下文管理器,vmap是一个独立的函数包装器。JAX 中,这四个变换共享相同的调用签名和语义模型。2.2 不可变数组与纯函数约束JAX 的核心约束是:所有 JAX 变换只接受纯函数。纯函数的定义:函数的输出仅依赖于其输入参数函数没有副作用(不修改全局状态、不打印、不读写文件)JAX 数组(jax.Array)是不可变的,这与 NumPy 形成根本性差异:importjax.numpyasjnpimportnumpyasnp# NumPy:就地修改x=np.array([1,2,3])x[0]=10# 成功,x[0] 现在是 10# JAX:不可变x=jnp.array([1,2,3])x=x.at[0].set(10)# 返回新数组,原数组 x 不变不可变性的设计依据在于编译器的需求。XLA 编译器需要确定性地推理数据流——如果数组可以在任意位置被修改,编译器就无法安全地进行算子融合、内存复用、缓冲区别名分析等优化。JAX 选择以"无副作用 + 不可变"换取编译器的激进优化空间。2.3 Tracing 与 Jaxpr:JAX 的中间表示JAX 实现函数变换的核心机制是Tracing(追踪)。当调用jit(f)(x)时:[Python 函数 f] ↓ 传入抽象追踪器(Abstract Tracer)而非真实数组 [Tracing 阶段] —— 逐行追踪 f 的 Python 代码 │ 每个操作记录到计算图而非被执行 │ 追踪器携带 shape + dtype 但无具体数值 ↓ [Jaxpr 生成] —— JAX Program Representation │ jaxpr 是 JAX 的中间表示 │ 由简单的函数式原语(primitive)序列组成 ↓ [XLA 编译] —— jaxpr → StableHLO → HLO → 平台代码 │ ↓ [可执行文件] —— 直接跑在 GPU/TPU/CPU 上Jaxpr 是理解 JAX 内部运作最关键的抽象。它是一个小型的函数式 IR,仅包含以下元素:常量(ConstVar):编译时确定的字面量变量(Var):中间计算结果原语(Primitive):不可再分的底层操作(如add、dot_general、reduce_sum)等式(Equation):[out_vars] = primitive(input_vars; params)子调用(Subjaxpr):用于表示控制流(lax.cond、lax.scan等)下面是一个简单的 jaxpr 示例:importjaxdeff(x,y):returnjax.numpy.dot(x,y)+1.0# 查看 jaxprprint(jax.make_jaxpr(f)(jax.numpy.ones(3),jax.numpy.ones(3)))输出(简化表示):{ lambda ; a:f32[3] b:f32[3]. let c:f32[] = dot_general[dimension_numbers=(([0],[0]),([],[]))] a b d:f32[] = add c 1.0 in (d,) }这个 jaxpr 展示了追踪的核心价值:Python 控制流消失了,只剩下纯粹的数值操作序列。for循环被展开、if-elif-else被解析为cond原语、函数调用被内联。编译器看到的是一个没有 Python 语义干扰的纯计算图。3. 函数变换体系:JAX 的四大基石3.1jax.jit:即时编译jax.jit是 JAX 最核心的性能优化工具。它的工作原理:第一步:函数追踪。JIT 传入的函数的 Python 代码被逐行追踪,所有 JAX 操作被记录而非执行。追踪过程中,传入参数被替换为抽象值——只知道 shape 和 dtype,不知道具体数值。第二步:Jaxpr 生成。追踪结果被转换为 Jaxpr 中间表示。这个 IR 剥离了所有 Python 语义,只保留纯数值计算。第三步:XLA 编译。Jaxpr 被进一步降级为 StableHLO→HLO 表示,交由 XLA 编译器进行图优化——死代码消除、算子融合、代数简化、内存规划——最终生成平台特定的可执行代码(PTX for CUDA、VLIW for TPU)。第四步:缓存。编译结果按(函数签名, 参数 shape+dtype)缓存。相同签名的后续调用跳过编译,直接执行缓存的二进制代码。JIT 的限制:静态 Shape:所有数组的 shape 必须在编译时已知。x[x 0]返回的数组大小取决于数据内容,无法在编译时确定,会触发NonConcreteBooleanIndexError。无副作用:不能在 JIT 函数内使用print()(应使用jax.debug.print())、修改全局变量、或执行 I/O 操作。控制流必须用 JAX 语义:if语句必须替换为jax.lax.cond,for循环必须替换为jax.lax.fori_loop或jax.lax.scan。importjaximportjax.numpyasjnp@jax.jitdefselu(x,alpha=1.67326,lmbda=1.0507):returnlmbda*jnp.where(x0,x,alpha*jnp.exp(x)-alpha)# 第一次调用触发编译(warming-up),后续调用享受缓存x=jnp.arange(1000000.)%timeit selu(x).block_until_ready()# JIT 版本通常比非 JIT 版本快 10-100 倍(取决于算子粒度)3.2jax.grad:自动微分JAX 的自动微分基于**源码变换(Source Code Transformation)**而非 PyTorch 的计算图重放。当调用grad(f)时,JAX 执行以下步骤:追踪f生成前向 jaxpr对 jaxpr 中的每个原语(primitive),查找其**向量-雅可比乘积(VJP)**规则自动生成反向传播的 jaxpr返回一个计算梯度的新函数这意味着:JAX 的自动微分是"编译时"的——梯度计算代码由编译器自动生成,而非在运行时通过计算图重放。这一差异带来的关键优势:高阶微分极其自然。因为grad返回的也是一个 JAX 可追踪的纯函数,所以可以直接对梯度函数再次调用grad:deff(x):returnjnp.sum(x**3)df_dx=jax.grad(f)# 一阶导数: 3x^2d2f_dx2=jax.grad(df_dx)# 二阶导数: 6x → 等价于 jax.grad(jax.grad(f))d3f_dx3=jax.grad(d2f_dx2)
JAX 深度学习框架核心机制深度解析:从函数变换到自动并行化的编译优化原理
发布时间:2026/7/1 2:29:53
JAX 深度学习框架核心机制深度解析:从函数变换到自动并行化的编译优化原理前言核心痛点:本文解决业界对 JAX 框架底层机制的深度理解需求——多数 AI 工程师熟悉 PyTorch 的即时执行模式,但对 JAX 的函数式变换哲学、JIT 编译流水线、自动并行化机制缺乏系统性认知,导致在选型时无法客观评估两套技术栈的优劣,或在迁移到 JAX 生态时遭遇"思维范式墙"。适配人群:具备 PyTorch/TensorFlow 使用经验的中高级 AI 工程师、深度学习框架开发者、对编译器优化感兴趣的系统工程师、正在评估 JAX 技术栈的架构师。收获能力:读完可掌握 JAX 函数变换体系(jit/grad/vmap/pmap)的底层原理 + XLA 编译优化全链路 + SPMD 自动并行化机制 + 生产级分布式训练落地实战能力。目录1. 技术背景与演进逻辑2. 核心原理深度解析3. 函数变换体系:JAX 的四大基石4. XLA 编译流水线与 Jaxpr 中间表示5. 分布式并行化架构6. JAX 生态体系全景7. JAX vs PyTorch 技术对比8. 技术优缺点与适用场景9. 实战落地10. 全文总结11. 系列说明12. 参考资料1. 技术背景与演进逻辑1.1 JAX 的诞生背景2018 年,Google Brain 团队发布了一篇名为《JAX: composable transformations of Python+NumPy programs》的技术报告,正式向社区推出 JAX 框架。彼时,深度学习框架的竞争格局已经明朗:TensorFlow 凭借静态图 + 工业级部署能力占据生产环境主导地位,PyTorch 以动态图 + Pythonic 编程体验迅速赢得研究社区的青睐。然而,这两个主流框架在设计哲学上都存在各自的妥协。TensorFlow 1.x 的静态图虽然能进行全图优化,但session.run()的编程模型割裂了 Python 控制流与计算图构建,调试体验极为痛苦。PyTorch 的即时执行(eager execution)虽然调试友好,但运算逐条下发到设备执行,缺少跨操作的全局优化空间——即便后来的torch.compile通过TorchDynamo捕获子图进行部分编译,其优化深度仍受限于 Python 解释器的"图断裂(graph break)"问题。JAX 的创始团队看到了第三条路:将 NumPy 的易用性、函数式编程的可组合性、编译器优化的极致性能三者融合。他们选择的核心理念是:不是构建一个新的深度学习框架,而是构建一个通用的数值计算编译器,深度学习只是它的一个应用场景。这一理念体现在 JAX 的设计取舍中:设计维度PyTorchTensorFlow 2.xJAX执行模型即时执行 + 选择性编译即时执行 +tf.function默认即时执行 +jit编译自动微分动态计算图(tape-based)动态计算图(tape-based)函数变换(源码级变换)中间表示TorchDynamo → FX Graph → InductorGrappler → MLIR → XLATracing → Jaxpr → StableHLO → HLO并行模型DDP / FSDP(手动配置)tf.distribute(策略模式)jit+ sharding(编译器自动决策)数组语义可变(mutable)可变(mutable)不可变(immutable)随机数全局状态全局状态显式 Key(无状态)函数变换不支持不支持一等公民(jit/grad/vmap/pmap 任意组合)JAX 目前的最新稳定版本是v0.6.0(2026 年 6 月),底层编译器已从 XLA 迁移至OpenXLA社区开源项目,实现了与 TensorFlow、PyTorch(通过torch_xla)共享编译器基础设施。1.2 传统框架的核心局限要理解 JAX 为什么以"函数变换"作为核心范式,需要先审视传统框架在以下场景中的局限:局限一:自动微分的扩展性瓶颈。PyTorch 的autograd基于动态计算图,每次前向传播都会构建一张新的计算图,反向传播完成后销毁。这个模型对于简单的前馈网络足够高效,但当需要计算高阶导数(如 Hessian 矩阵)、梯度的梯度(meta-learning)、或需要对同一函数多次求导(如物理信息神经网络 PINN)时,动态图的"一次性"特质导致代码复杂度和内存开销急剧膨胀。局限二:手动批处理的工程负担。研究者从单样本调试转向批量训练时,需要手动重写代码——加 batch 维度、调整矩阵乘法维度、处理 broadcasting 语义。torch.vmap虽然已加入 PyTorch,但其实验性质和使用限制(不支持所有算子)使得自动向量化仍未成为 PyTorch 的核心工作流。局限三:分布式训练的配置熵。PyTorch 的分布式训练需要研究者显式管理设备拓扑、通信策略、参数分片方式。从DataParallel到DistributedDataParallel再到FullyShardedDataParallel,API 层层嵌套。到了 Tensor Parallelism + Pipeline Parallelism + Data Parallelism 的 3D 并行阶段,配置复杂度呈指数增长。局限四:Python 的执行开销。在即时执行模式下,每个操作都涉及 Python → C++ → CUDA 的多层调用,对于细粒度操作(如自定义激活函数中的逐元素计算),Python 解释器开销可能超过实际数值计算时间。JAX 的 JIT 编译直接消除这一开销。1.3 JAX 的技术演进路线JAX 的发展可以分为四个阶段:[JAX 2018 发布] ↓ [第一阶段: 函数变换核心] ├── jit / grad / vmap / pmap 四大变换趋于稳定 ├── jax.numpy 覆盖 NumPy API 的 90%+ └── TPU 支持使得 Google 内部大规模采用 ↓ [第二阶段: 分布式架构重构] ├── pjit 引入 SPMD 编程模型 ├── GDA (Global Device Array) / jax.Array 统一多设备数据抽象 ├── NamedSharding / PartitionSpec 声明式分片方案 └── shard_map 提供 SPMD 手动控制 ↓ [第三阶段: 生态整合] ├── OpenXLA 社区接管编译器维护 ├── JAX ↔ PyTorch 互操作(jax2torch / torch2jax) ├── Flax / Haiku / Equinox / NNX 等 NN 库百花齐放 └── Orbax 统一检查点格式 ↓ [第四阶段: 生产就绪(当前)] ├── Google DeepMind Gemini/AlphaFold 全系基于 JAX ├── JAX on GPU/TPU/CPU 三平台成熟 ├── 多主机多切片(Multislice)训练支持 └── Pallas 自定义 Kernel 语言2. 核心原理深度解析2.1 函数式编程核心理念JAX 的本质是一个函数变换系统。它的设计遵循以下数学直觉:设有一个纯函数f: X → Y,JAX 提供的每一个变换都是一个高阶函数(Higher-Order Function),即输入一个函数、输出一个新函数:jit(f): X → Y—— 将普通 Python 函数编译为 XLA 优化的可执行代码grad(f): X → ∇f(X)—— 生成原函数的梯度函数vmap(f): X^batch → Y^batch—— 将单样本函数自动提升为批量函数pmap(f): X^replicated → Y^sharded—— 将函数分布到多个设备并行执行这四个变换的高阶函数特性意味着:它们可以任意组合。# 四重变换叠加:编译 + 向量化 + 自动微分 + 多设备并行@jit@vmap@graddefloss_fn(params,x,y):returnjnp.sum((predict(params,x)-y)**2)# 等价于手写:对每个样本计算梯度,然后 JIT 编译并在多设备上并行这种"变换可组合性"是 JAX 区别于一切传统框架的根本特征。PyTorch 的torch.compile、autograd、vmap也可以组合使用,但它们之间缺乏统一的形式化接口——torch.compile是一个 FX 图变换,autograd是一个上下文管理器,vmap是一个独立的函数包装器。JAX 中,这四个变换共享相同的调用签名和语义模型。2.2 不可变数组与纯函数约束JAX 的核心约束是:所有 JAX 变换只接受纯函数。纯函数的定义:函数的输出仅依赖于其输入参数函数没有副作用(不修改全局状态、不打印、不读写文件)JAX 数组(jax.Array)是不可变的,这与 NumPy 形成根本性差异:importjax.numpyasjnpimportnumpyasnp# NumPy:就地修改x=np.array([1,2,3])x[0]=10# 成功,x[0] 现在是 10# JAX:不可变x=jnp.array([1,2,3])x=x.at[0].set(10)# 返回新数组,原数组 x 不变不可变性的设计依据在于编译器的需求。XLA 编译器需要确定性地推理数据流——如果数组可以在任意位置被修改,编译器就无法安全地进行算子融合、内存复用、缓冲区别名分析等优化。JAX 选择以"无副作用 + 不可变"换取编译器的激进优化空间。2.3 Tracing 与 Jaxpr:JAX 的中间表示JAX 实现函数变换的核心机制是Tracing(追踪)。当调用jit(f)(x)时:[Python 函数 f] ↓ 传入抽象追踪器(Abstract Tracer)而非真实数组 [Tracing 阶段] —— 逐行追踪 f 的 Python 代码 │ 每个操作记录到计算图而非被执行 │ 追踪器携带 shape + dtype 但无具体数值 ↓ [Jaxpr 生成] —— JAX Program Representation │ jaxpr 是 JAX 的中间表示 │ 由简单的函数式原语(primitive)序列组成 ↓ [XLA 编译] —— jaxpr → StableHLO → HLO → 平台代码 │ ↓ [可执行文件] —— 直接跑在 GPU/TPU/CPU 上Jaxpr 是理解 JAX 内部运作最关键的抽象。它是一个小型的函数式 IR,仅包含以下元素:常量(ConstVar):编译时确定的字面量变量(Var):中间计算结果原语(Primitive):不可再分的底层操作(如add、dot_general、reduce_sum)等式(Equation):[out_vars] = primitive(input_vars; params)子调用(Subjaxpr):用于表示控制流(lax.cond、lax.scan等)下面是一个简单的 jaxpr 示例:importjaxdeff(x,y):returnjax.numpy.dot(x,y)+1.0# 查看 jaxprprint(jax.make_jaxpr(f)(jax.numpy.ones(3),jax.numpy.ones(3)))输出(简化表示):{ lambda ; a:f32[3] b:f32[3]. let c:f32[] = dot_general[dimension_numbers=(([0],[0]),([],[]))] a b d:f32[] = add c 1.0 in (d,) }这个 jaxpr 展示了追踪的核心价值:Python 控制流消失了,只剩下纯粹的数值操作序列。for循环被展开、if-elif-else被解析为cond原语、函数调用被内联。编译器看到的是一个没有 Python 语义干扰的纯计算图。3. 函数变换体系:JAX 的四大基石3.1jax.jit:即时编译jax.jit是 JAX 最核心的性能优化工具。它的工作原理:第一步:函数追踪。JIT 传入的函数的 Python 代码被逐行追踪,所有 JAX 操作被记录而非执行。追踪过程中,传入参数被替换为抽象值——只知道 shape 和 dtype,不知道具体数值。第二步:Jaxpr 生成。追踪结果被转换为 Jaxpr 中间表示。这个 IR 剥离了所有 Python 语义,只保留纯数值计算。第三步:XLA 编译。Jaxpr 被进一步降级为 StableHLO→HLO 表示,交由 XLA 编译器进行图优化——死代码消除、算子融合、代数简化、内存规划——最终生成平台特定的可执行代码(PTX for CUDA、VLIW for TPU)。第四步:缓存。编译结果按(函数签名, 参数 shape+dtype)缓存。相同签名的后续调用跳过编译,直接执行缓存的二进制代码。JIT 的限制:静态 Shape:所有数组的 shape 必须在编译时已知。x[x 0]返回的数组大小取决于数据内容,无法在编译时确定,会触发NonConcreteBooleanIndexError。无副作用:不能在 JIT 函数内使用print()(应使用jax.debug.print())、修改全局变量、或执行 I/O 操作。控制流必须用 JAX 语义:if语句必须替换为jax.lax.cond,for循环必须替换为jax.lax.fori_loop或jax.lax.scan。importjaximportjax.numpyasjnp@jax.jitdefselu(x,alpha=1.67326,lmbda=1.0507):returnlmbda*jnp.where(x0,x,alpha*jnp.exp(x)-alpha)# 第一次调用触发编译(warming-up),后续调用享受缓存x=jnp.arange(1000000.)%timeit selu(x).block_until_ready()# JIT 版本通常比非 JIT 版本快 10-100 倍(取决于算子粒度)3.2jax.grad:自动微分JAX 的自动微分基于**源码变换(Source Code Transformation)**而非 PyTorch 的计算图重放。当调用grad(f)时,JAX 执行以下步骤:追踪f生成前向 jaxpr对 jaxpr 中的每个原语(primitive),查找其**向量-雅可比乘积(VJP)**规则自动生成反向传播的 jaxpr返回一个计算梯度的新函数这意味着:JAX 的自动微分是"编译时"的——梯度计算代码由编译器自动生成,而非在运行时通过计算图重放。这一差异带来的关键优势:高阶微分极其自然。因为grad返回的也是一个 JAX 可追踪的纯函数,所以可以直接对梯度函数再次调用grad:deff(x):returnjnp.sum(x**3)df_dx=jax.grad(f)# 一阶导数: 3x^2d2f_dx2=jax.grad(df_dx)# 二阶导数: 6x → 等价于 jax.grad(jax.grad(f))d3f_dx3=jax.grad(d2f_dx2)