0x00 摘要subgraph_extractor.py 是 KernelFalcon 实现 “PyTorch 模型子图提取 形状签名去重” 的关键组件核心职责是通过 Fuser 生成融合代码后借助 LLM 解析并提取模型中唯一的计算子图按形状 / 算子 / 权重特征去重最终输出标准化 JSON 格式的子图信息。这一模块体现了 “Agent 端到端优化” 中 “精准子图识别” 的关键能力。Extractor 的架构如下其三阶段总结如下融合重写调Orchestrator多worker并行让LLM把原PyTorch改写成可融合子模块沙箱验证数值等价产出code.py.tgz.子图识别把原问题融合代码一起喂给LLM要求输出每个唯一子图的 JSON描述ops、shapes、weights、来源代码片段去重合并按opsshapesweightslayoutdtype算签名相同签名合并count确保Dispatcher 只为每个独特子图生成一次内核0x01 功能详解subgraph_extractor.py 是子图提取的关键组件用于从融合的 PyTorch 代码中识别和提取可融合的子图及其形状信息。这个组件使得复杂的融合操作可以分解为更小、更易管理的子图每个子图都有精确的形状和语义信息为后续的 Triton 内核生成奠定了坚实基础。1.1 核心作用subgraph_extractor.py 的核心作用如下分析融合代码从 Fuser 生成的融合 PyTorch 代码中提取语义信息识别可以独立优化的子图形状感知提取精确的输入 / 输出形状用于后续优化生成结构化描述创建精确的 JSON 格式子图描述去重机制基于形状签名消除重复子图为后续阶段提供输入为 Triton 内核生成和最终合成提供基础数据主要功能流程如下。问题文件 ↓ Fuser Orchestrator 生成融合代码 ↓ subgraph_extractor.py 分析融合代码并提取子图 ↓ subgraphs.json 结构化子图描述1.2 详细分析代码提取功能_load_code_from_tar 函数完成代码提取功能。def _load_code_from_tar(artifact_path: Path) - str: 从tar.gz压缩包中读取code.py文件内容Fuser生成的融合代码 # 检查压缩包文件是否存在不存在则返回空字符串 if not artifact_path.is_file(): return # 以只读模式打开gzip压缩的tar包 with tarfile.open(artifact_path, r:gz) as tf: try: # 获取压缩包中名为code.py的文件成员 member tf.getmember(code.py) except KeyError: # 若code.py不存在返回空字符串 return # 提取code.py文件内容 extracted tf.extractfile(member) # 若提取失败文件为空返回空字符串 if extracted is None: return # 读取文件内容并解码为UTF-8字符串返回 return extracted.read().decode(utf-8)LLM提示构建_build_llm_prompt_for_shapes 完成了 prompt构建功能其关键特点是精确性要求精确的形状签名结构化强制返回特定的 JSON 格式完整性包含操作、权重、布局等所有相关信息def _build_llm_prompt_for_shapes(fused_code: str, problem_code: str) - tuple[str, str]: 构建LLM提示词引导LLM分析融合代码和原始代码提取子图信息 # System Prompt强制要求仅返回JSON数组 system Return a single JSON array only. user_lines: list[str] [] # 角色与背景说明告知LLM输入内容原始问题代码融合代码 user_lines.append( You are given:\n- The original problem (PyTorch).\n- A fused refactor produced by Fuser (PyTorch subgraph modules). ) # 核心任务说明按形状签名识别唯一子图输出指定Schema的JSON数组 user_lines.append( Task: Identify every unique subgraph by exact shape signature and emit a JSON array matching this schema (and only this schema): ) # 详细Schema定义明确每个字段的含义和格式 user_lines.append( {\n id: string,\n type: string,\n data_layout: \\NCHW\\|\\NHWC\\|null,\n dtype: string|null,\n ops: [ {op: string, ... op-specific fields ... } ],\n input_shape: [int|sym, ...] // OR \\inputs\\: [[...], [...]] for multi-input\n output_shape: [int|sym, ...],\n weights_fused: { name: [int|sym, ...], ... } | null,\n weights_original: { name: [int|sym, ...], ... } | null,\n count: int,\n where: string,\n source: { module: string, code: string }\n } ) # 关键注意事项细化提取规则提升准确性 user_lines.append(Notes:) user_lines.append( - Treat any shape difference (inputs/outputs/weights) as a distinct subgraph. Count occurrences. ) user_lines.append( - Populate op-specific fields for conv/pool/linear, e.g., kernel_size/stride/padding/groups, bn_fused, output_size, start_dim. ) user_lines.append( - Include both weights_original (pre-fusion params like BN gamma/beta/running stats) and weights_fused (post-fusion conv/bias). Use null if not applicable. ) user_lines.append( - Provide a short \where\ string (e.g., Model.forward stem or layer2.block3.conv). ) user_lines.append( - Provide source with the smallest contiguous code snippet implementing the subgraph. ) user_lines.append( - Use data_layout and dtype when clear (default conv layout is NCHW). ) user_lines.append( - For binary ops like residual add, use inputs: [[...],[...]]. ) user_lines.append( - Prefer concrete integers from get_inputs() shapes in the problem; otherwise use symbols like B, H, W. ) user_lines.append() # 输入代码原始问题代码 user_lines.append(PROBLEM_FILE:\npython) user_lines.append(problem_code) user_lines.append() user_lines.append() # 输入代码Fuser生成的融合代码 user_lines.append(FUSED_CODE:) user_lines.append(python) user_lines.append(fused_code) user_lines.append() user_lines.append() # 最终要求仅返回包含数组的JSON代码块无其他文本 user_lines.append( Now return only one fenced JSON block containing the array. No prose. ) # 返回System Prompt和User Prompt return system, \n.join(user_lines)形状签名去重机制_dedup_by_shape_signature 实现了去重代码。基于输入 / 权重 / 输出形状的标准化表示忽略名称但保留维度和数据类型确保相同语义的子图被合并def _dedup_by_shape_signature(items: list[dict[str, Any]]) - list[dict[str, Any]]: Deduplicate items by a stable shape signature. The signature is based on sorted lists of input/weight/output shapes content, ignoring names but preserving dimensions and dtypes. 基于稳定的形状签名对子图列表去重 - 签名基于输入/权重/输出形状的标准化内容忽略名称保留维度和数据类型 - 保证相同形状特征的子图只保留一个 def norm_shapes(arr: Any) - Any: 内部函数标准化形状数组统一不同格式的形状描述 # 非列表类型直接返回空列表 if not isinstance(arr, list): return [] normed: list[Any] [] # 遍历数组中的每个元素 for e in arr: if isinstance(e, dict): # 标准化形状字典的键兼容不同命名方式shape/dims/size shape e.get(shape) or e.get(dims) or e.get(size) dtype e.get(dtype) kind e.get(kind) or e.get(role) # 标准化维度优先int/str类型其他类型转为字符串 if isinstance(shape, list): dims [str(x) for x in shape] elif isinstance(shape, (int, str)): dims [str(shape)] else: dims [str(shape)] if shape is not None else [] # 构建标准化的形状描述字典 normed.append( {dims: dims, dtype: str(dtype) if dtype else None, k: kind} ) else: # 非字典元素直接转为字符串 normed.append(str(e)) # 排序以保证签名的稳定性避免顺序不同导致签名不同 return sorted(normed, keylambda x: json.dumps(x, sort_keysTrue)) # 存储已见过的签名避免重复 seen: set[str] set() out: list[dict[str, Any]] [] # 遍历所有子图项 for it in items: # 构建签名对象包含输入/权重/输出的标准化形状 sig_obj { inputs: norm_shapes(it.get(input_shapes)), weights: norm_shapes(it.get(weight_shapes) or it.get(weights)), outputs: norm_shapes(it.get(output_shapes)), } # 转为JSON字符串作为唯一签名排序保证稳定性 sig json.dumps(sig_obj, sort_keysTrue) # 若签名未见过则保留该子图 if sig in seen: continue seen.add(sig) out.append(it) # 返回去重后的子图列表 return out1.3 流程图subgraph_extractor.py 的流程如下初始化阶段创建OrchestratorConfig配置对象生成唯一的运行ID并创建运行目录结构初始化Orchestrator对象代码提取阶段运行Orchestrator.run()获取融合后的PyTorch代码检查是否成功找到解决方案加载原始问题代码和融合后的代码LLM分析阶段构建包含原始问题和融合代码的提示根据提供商类型选择不同的API调用方式提取并解析LLM返回的JSON格式的子图描述后处理阶段验证JSON结构的有效性通过形状签名对子图进行去重和合并保存最终的subgraphs.json文件返回运行目录和JSON文件路径具体流程图如下1.4 与系统其他组件的交互与 Orchestrator 的交互调用 Fuser Orchestrator 生成融合代码orch Orchestrator(...) summary orch.run() fused_code _load_code_from_tar(Path(summary.artifact_path))与 Dispatch Kernel Agent 的交互生成的 subgraphs.json 作为 dispatch_kernel_agent.py 的输入为每个子图生成 Triton 内核。与 Composer 的交互subgraphs.json 作为 compose_end_to_end.py 的输入之一用于最终的端到端合成LLM 交互机制# Provider 选择 provider get_model_provider(model_name) if provider.name ! openai: # 直接调用提供商 result provider.get_response(...) else: # 通过 EventAdapter 流式处理 adapter EventAdapter(...) result adapter.stream( system_promptSYSTEM_PROMPT, user_promptrp.user, extrasrp.extras)0x02 Prompt我们来分析 subgraph_extractor.py 中的 Prompt 构建机制。2.1 概括subgraph_extractor.py 使用的是 LLM 提示专门用于从融合的 PyTorch 代码中提取子图及其形状信息。这条 prompt 可以一句话概括“把话说到编译器级别不给自由发挥留缝隙。”具体特点拆解如下极端结构化用 JSON Schema 把字段名、类型、取值范围、嵌套层级一次性钉死连null能出现在哪都标好。要求“只返回一个 fenced code block”直接把自然语言出口焊死防止模型“顺便聊聊”。双重代码上下文同时给出“原始 PyTorch 代码”和“融合后的代码”让模型既能看到“改名前的权重”也能看到“融合后的权重”相当于开卷考试但限定只能写标准答案格式。微观操作级说明书对每一类算子conv、pool、linear、add都列出必须出现的 keykernel_size/stride/padding/groups…把“该抄哪几行”写成 checklist模型只要漏一项就能被后处理脚本一键拒收。明确“形状不同就算新子图”避免模型把不同 block 的同名层合并。符号系统与优先级双重约束先拿get_inputs()的 concrete shape 当“硬数”找不到才允许用B/H/W符号既保证可静态检查又留一条退路。权重必须同时给weights_original和weights_fused逼模型把“融合前后张量对应关系”显式写出来防止“黑箱合并”。Zero-shot 但 Zero-creativity没有 few-shot 示例却用 12 条“Notes”把边界情况全部穷举等于告诉模型“你不需要创新只需要当一台会数数的扫描仪”。最后用“No prose”把寒暄、总结、解释统统 ban 掉输出直接变成可json.loads的“机器口粮”。2.2 Prompt 的基本结构System PromptSystem Prompt 的内容如下Return a single JSON array only.其要求 LLM 只返回单个 JSON 数组避免返回额外的文本说明。User Prompt 详细结构首先是背景介绍user_lines.append( You are given:\n- The original problem (PyTorch).\n- A fused refactor produced by Fuser (PyTorch subgraph modules). )其次是任务描述从融合代码中识别所有独特的子图提取精确的形状信息输入 / 输出 / 权重为后续的 Triton 内核生成提供结构化输入user_lines.append( Task: Identify every unique subgraph by exact shape signature and emit a JSON array matching this schema (and only this schema): )接下来会说明期望的JSON schemauser_lines.append( {\n id: string,\n type: string,\n data_layout: \\NCHW\\|\\NHWC\\|null,\n dtype: string|null,\n ops: [ {op: string, ... op-specific fields ... } ],\n input_shape: [int|sym, ...] // OR \\inputs\\: [[...], [...]] for multi-input\n output_shape: [int|sym, ...],\n weights_fused: { name: [int|sym, ...], ... } | null,\n weights_original: { name: [int|sym, ...], ... } | null,\n count: int,\n where: string,\n source: { module: string, code: string }\n } )然后是详细说明和注意事项user_lines.append(Notes:) user_lines.append( - Treat any shape difference (inputs/outputs/weights) as a distinct subgraph. Count occurrences. ) user_lines.append( - Populate op-specific fields for conv/pool/linear, e.g., kernel_size/stride/padding/groups, bn_fused, output_size, start_dim. ) user_lines.append( - Include both weights_original (pre-fusion params like BN gamma/beta/running stats) and weights_fused (post-fusion conv/bias). Use null if not applicable. ) user_lines.append( - Provide a short \where\ string (e.g., Model.forward stem or layer2.block3.conv). ) user_lines.append( - Provide source with the smallest contiguous code snippet implementing the subgraph. ) user_lines.append( - Use data_layout and dtype when clear (default conv layout is NCHW). ) user_lines.append( - For binary ops like residual add, use inputs: [[...],[...]]. ) user_lines.append( - Prefer concrete integers from get_inputs() shapes in the problem; otherwise use symbols like B, H, W. )最后是输入代码示例user_lines.append(PROBLEM_FILE:\npython) user_lines.append(problem_code) user_lines.append() user_lines.append() user_lines.append(FUSED_CODE:) user_lines.append(python) user_lines.append(fused_code) user_lines.append() user_lines.append() user_lines.append( Now return only one fenced JSON block containing the array. No prose. )2.3 使用时机在 extract_subgraphs_to_json 函数中会调用 prompt# Ask LLM for shapes JSON system, user _build_llm_prompt_for_shapes(fused_code, problem_code) Temporary MUX to support Relay while we migrate to OpenAI Responses API. Uses EventAdapter for OpenAI, otherwise Provider inferface provider get_model_provider(model_name) if provider.name ! openai: # 直接调用提供商 result provider.get_response(...) else: # 通过 EventAdapter 流式处理 adapter EventAdapter(...) result adapter.stream(...)0x03 实现subgraph_extractor.py 实现了 KernelFalcon 的 “子图识别” 核心能力 —— 通过 Fuser 生成融合代码→LLM 解析代码提取子图→签名去重合并→输出标准化 JSON为后续 Triton 算子自动生成提供精准的子图粒度输入。3.1 特色LLM 驱动的智能子图识别放弃传统的 “静态代码解析 规则匹配”改用 LLM 理解 PyTorch 代码语义精准识别卷积 / 池化 / 线性层等算子的子图边界、形状、权重特征适配复杂的融合代码场景鲁棒的签名去重机制基于 “算子 输入 / 输出形状 权重结构 数据布局 数据类型” 构建稳定签名避免因命名 / 格式差异导致的重复子图保证子图识别的唯一性全链路容错设计针对 LLM 输出格式异常、JSON 解析失败、代码文件缺失等场景均有明确的容错逻辑和诊断文件输出提升工业级可用性标准化输出格式定义统一的子图 JSON Schema包含 id、类型、形状、权重、计数等核心字段为后续算子生成和模型优化提供标准化输入适配多 LLM 提供商兼容 OpenAI Responses API 和其他 LLM 提供商的接口通过适配层统一调用逻辑保证灵
PyTorch KernelAgent 源码解读 ---(4)--- ExtractorAgent
发布时间:2026/6/28 10:52:35
0x00 摘要subgraph_extractor.py 是 KernelFalcon 实现 “PyTorch 模型子图提取 形状签名去重” 的关键组件核心职责是通过 Fuser 生成融合代码后借助 LLM 解析并提取模型中唯一的计算子图按形状 / 算子 / 权重特征去重最终输出标准化 JSON 格式的子图信息。这一模块体现了 “Agent 端到端优化” 中 “精准子图识别” 的关键能力。Extractor 的架构如下其三阶段总结如下融合重写调Orchestrator多worker并行让LLM把原PyTorch改写成可融合子模块沙箱验证数值等价产出code.py.tgz.子图识别把原问题融合代码一起喂给LLM要求输出每个唯一子图的 JSON描述ops、shapes、weights、来源代码片段去重合并按opsshapesweightslayoutdtype算签名相同签名合并count确保Dispatcher 只为每个独特子图生成一次内核0x01 功能详解subgraph_extractor.py 是子图提取的关键组件用于从融合的 PyTorch 代码中识别和提取可融合的子图及其形状信息。这个组件使得复杂的融合操作可以分解为更小、更易管理的子图每个子图都有精确的形状和语义信息为后续的 Triton 内核生成奠定了坚实基础。1.1 核心作用subgraph_extractor.py 的核心作用如下分析融合代码从 Fuser 生成的融合 PyTorch 代码中提取语义信息识别可以独立优化的子图形状感知提取精确的输入 / 输出形状用于后续优化生成结构化描述创建精确的 JSON 格式子图描述去重机制基于形状签名消除重复子图为后续阶段提供输入为 Triton 内核生成和最终合成提供基础数据主要功能流程如下。问题文件 ↓ Fuser Orchestrator 生成融合代码 ↓ subgraph_extractor.py 分析融合代码并提取子图 ↓ subgraphs.json 结构化子图描述1.2 详细分析代码提取功能_load_code_from_tar 函数完成代码提取功能。def _load_code_from_tar(artifact_path: Path) - str: 从tar.gz压缩包中读取code.py文件内容Fuser生成的融合代码 # 检查压缩包文件是否存在不存在则返回空字符串 if not artifact_path.is_file(): return # 以只读模式打开gzip压缩的tar包 with tarfile.open(artifact_path, r:gz) as tf: try: # 获取压缩包中名为code.py的文件成员 member tf.getmember(code.py) except KeyError: # 若code.py不存在返回空字符串 return # 提取code.py文件内容 extracted tf.extractfile(member) # 若提取失败文件为空返回空字符串 if extracted is None: return # 读取文件内容并解码为UTF-8字符串返回 return extracted.read().decode(utf-8)LLM提示构建_build_llm_prompt_for_shapes 完成了 prompt构建功能其关键特点是精确性要求精确的形状签名结构化强制返回特定的 JSON 格式完整性包含操作、权重、布局等所有相关信息def _build_llm_prompt_for_shapes(fused_code: str, problem_code: str) - tuple[str, str]: 构建LLM提示词引导LLM分析融合代码和原始代码提取子图信息 # System Prompt强制要求仅返回JSON数组 system Return a single JSON array only. user_lines: list[str] [] # 角色与背景说明告知LLM输入内容原始问题代码融合代码 user_lines.append( You are given:\n- The original problem (PyTorch).\n- A fused refactor produced by Fuser (PyTorch subgraph modules). ) # 核心任务说明按形状签名识别唯一子图输出指定Schema的JSON数组 user_lines.append( Task: Identify every unique subgraph by exact shape signature and emit a JSON array matching this schema (and only this schema): ) # 详细Schema定义明确每个字段的含义和格式 user_lines.append( {\n id: string,\n type: string,\n data_layout: \\NCHW\\|\\NHWC\\|null,\n dtype: string|null,\n ops: [ {op: string, ... op-specific fields ... } ],\n input_shape: [int|sym, ...] // OR \\inputs\\: [[...], [...]] for multi-input\n output_shape: [int|sym, ...],\n weights_fused: { name: [int|sym, ...], ... } | null,\n weights_original: { name: [int|sym, ...], ... } | null,\n count: int,\n where: string,\n source: { module: string, code: string }\n } ) # 关键注意事项细化提取规则提升准确性 user_lines.append(Notes:) user_lines.append( - Treat any shape difference (inputs/outputs/weights) as a distinct subgraph. Count occurrences. ) user_lines.append( - Populate op-specific fields for conv/pool/linear, e.g., kernel_size/stride/padding/groups, bn_fused, output_size, start_dim. ) user_lines.append( - Include both weights_original (pre-fusion params like BN gamma/beta/running stats) and weights_fused (post-fusion conv/bias). Use null if not applicable. ) user_lines.append( - Provide a short \where\ string (e.g., Model.forward stem or layer2.block3.conv). ) user_lines.append( - Provide source with the smallest contiguous code snippet implementing the subgraph. ) user_lines.append( - Use data_layout and dtype when clear (default conv layout is NCHW). ) user_lines.append( - For binary ops like residual add, use inputs: [[...],[...]]. ) user_lines.append( - Prefer concrete integers from get_inputs() shapes in the problem; otherwise use symbols like B, H, W. ) user_lines.append() # 输入代码原始问题代码 user_lines.append(PROBLEM_FILE:\npython) user_lines.append(problem_code) user_lines.append() user_lines.append() # 输入代码Fuser生成的融合代码 user_lines.append(FUSED_CODE:) user_lines.append(python) user_lines.append(fused_code) user_lines.append() user_lines.append() # 最终要求仅返回包含数组的JSON代码块无其他文本 user_lines.append( Now return only one fenced JSON block containing the array. No prose. ) # 返回System Prompt和User Prompt return system, \n.join(user_lines)形状签名去重机制_dedup_by_shape_signature 实现了去重代码。基于输入 / 权重 / 输出形状的标准化表示忽略名称但保留维度和数据类型确保相同语义的子图被合并def _dedup_by_shape_signature(items: list[dict[str, Any]]) - list[dict[str, Any]]: Deduplicate items by a stable shape signature. The signature is based on sorted lists of input/weight/output shapes content, ignoring names but preserving dimensions and dtypes. 基于稳定的形状签名对子图列表去重 - 签名基于输入/权重/输出形状的标准化内容忽略名称保留维度和数据类型 - 保证相同形状特征的子图只保留一个 def norm_shapes(arr: Any) - Any: 内部函数标准化形状数组统一不同格式的形状描述 # 非列表类型直接返回空列表 if not isinstance(arr, list): return [] normed: list[Any] [] # 遍历数组中的每个元素 for e in arr: if isinstance(e, dict): # 标准化形状字典的键兼容不同命名方式shape/dims/size shape e.get(shape) or e.get(dims) or e.get(size) dtype e.get(dtype) kind e.get(kind) or e.get(role) # 标准化维度优先int/str类型其他类型转为字符串 if isinstance(shape, list): dims [str(x) for x in shape] elif isinstance(shape, (int, str)): dims [str(shape)] else: dims [str(shape)] if shape is not None else [] # 构建标准化的形状描述字典 normed.append( {dims: dims, dtype: str(dtype) if dtype else None, k: kind} ) else: # 非字典元素直接转为字符串 normed.append(str(e)) # 排序以保证签名的稳定性避免顺序不同导致签名不同 return sorted(normed, keylambda x: json.dumps(x, sort_keysTrue)) # 存储已见过的签名避免重复 seen: set[str] set() out: list[dict[str, Any]] [] # 遍历所有子图项 for it in items: # 构建签名对象包含输入/权重/输出的标准化形状 sig_obj { inputs: norm_shapes(it.get(input_shapes)), weights: norm_shapes(it.get(weight_shapes) or it.get(weights)), outputs: norm_shapes(it.get(output_shapes)), } # 转为JSON字符串作为唯一签名排序保证稳定性 sig json.dumps(sig_obj, sort_keysTrue) # 若签名未见过则保留该子图 if sig in seen: continue seen.add(sig) out.append(it) # 返回去重后的子图列表 return out1.3 流程图subgraph_extractor.py 的流程如下初始化阶段创建OrchestratorConfig配置对象生成唯一的运行ID并创建运行目录结构初始化Orchestrator对象代码提取阶段运行Orchestrator.run()获取融合后的PyTorch代码检查是否成功找到解决方案加载原始问题代码和融合后的代码LLM分析阶段构建包含原始问题和融合代码的提示根据提供商类型选择不同的API调用方式提取并解析LLM返回的JSON格式的子图描述后处理阶段验证JSON结构的有效性通过形状签名对子图进行去重和合并保存最终的subgraphs.json文件返回运行目录和JSON文件路径具体流程图如下1.4 与系统其他组件的交互与 Orchestrator 的交互调用 Fuser Orchestrator 生成融合代码orch Orchestrator(...) summary orch.run() fused_code _load_code_from_tar(Path(summary.artifact_path))与 Dispatch Kernel Agent 的交互生成的 subgraphs.json 作为 dispatch_kernel_agent.py 的输入为每个子图生成 Triton 内核。与 Composer 的交互subgraphs.json 作为 compose_end_to_end.py 的输入之一用于最终的端到端合成LLM 交互机制# Provider 选择 provider get_model_provider(model_name) if provider.name ! openai: # 直接调用提供商 result provider.get_response(...) else: # 通过 EventAdapter 流式处理 adapter EventAdapter(...) result adapter.stream( system_promptSYSTEM_PROMPT, user_promptrp.user, extrasrp.extras)0x02 Prompt我们来分析 subgraph_extractor.py 中的 Prompt 构建机制。2.1 概括subgraph_extractor.py 使用的是 LLM 提示专门用于从融合的 PyTorch 代码中提取子图及其形状信息。这条 prompt 可以一句话概括“把话说到编译器级别不给自由发挥留缝隙。”具体特点拆解如下极端结构化用 JSON Schema 把字段名、类型、取值范围、嵌套层级一次性钉死连null能出现在哪都标好。要求“只返回一个 fenced code block”直接把自然语言出口焊死防止模型“顺便聊聊”。双重代码上下文同时给出“原始 PyTorch 代码”和“融合后的代码”让模型既能看到“改名前的权重”也能看到“融合后的权重”相当于开卷考试但限定只能写标准答案格式。微观操作级说明书对每一类算子conv、pool、linear、add都列出必须出现的 keykernel_size/stride/padding/groups…把“该抄哪几行”写成 checklist模型只要漏一项就能被后处理脚本一键拒收。明确“形状不同就算新子图”避免模型把不同 block 的同名层合并。符号系统与优先级双重约束先拿get_inputs()的 concrete shape 当“硬数”找不到才允许用B/H/W符号既保证可静态检查又留一条退路。权重必须同时给weights_original和weights_fused逼模型把“融合前后张量对应关系”显式写出来防止“黑箱合并”。Zero-shot 但 Zero-creativity没有 few-shot 示例却用 12 条“Notes”把边界情况全部穷举等于告诉模型“你不需要创新只需要当一台会数数的扫描仪”。最后用“No prose”把寒暄、总结、解释统统 ban 掉输出直接变成可json.loads的“机器口粮”。2.2 Prompt 的基本结构System PromptSystem Prompt 的内容如下Return a single JSON array only.其要求 LLM 只返回单个 JSON 数组避免返回额外的文本说明。User Prompt 详细结构首先是背景介绍user_lines.append( You are given:\n- The original problem (PyTorch).\n- A fused refactor produced by Fuser (PyTorch subgraph modules). )其次是任务描述从融合代码中识别所有独特的子图提取精确的形状信息输入 / 输出 / 权重为后续的 Triton 内核生成提供结构化输入user_lines.append( Task: Identify every unique subgraph by exact shape signature and emit a JSON array matching this schema (and only this schema): )接下来会说明期望的JSON schemauser_lines.append( {\n id: string,\n type: string,\n data_layout: \\NCHW\\|\\NHWC\\|null,\n dtype: string|null,\n ops: [ {op: string, ... op-specific fields ... } ],\n input_shape: [int|sym, ...] // OR \\inputs\\: [[...], [...]] for multi-input\n output_shape: [int|sym, ...],\n weights_fused: { name: [int|sym, ...], ... } | null,\n weights_original: { name: [int|sym, ...], ... } | null,\n count: int,\n where: string,\n source: { module: string, code: string }\n } )然后是详细说明和注意事项user_lines.append(Notes:) user_lines.append( - Treat any shape difference (inputs/outputs/weights) as a distinct subgraph. Count occurrences. ) user_lines.append( - Populate op-specific fields for conv/pool/linear, e.g., kernel_size/stride/padding/groups, bn_fused, output_size, start_dim. ) user_lines.append( - Include both weights_original (pre-fusion params like BN gamma/beta/running stats) and weights_fused (post-fusion conv/bias). Use null if not applicable. ) user_lines.append( - Provide a short \where\ string (e.g., Model.forward stem or layer2.block3.conv). ) user_lines.append( - Provide source with the smallest contiguous code snippet implementing the subgraph. ) user_lines.append( - Use data_layout and dtype when clear (default conv layout is NCHW). ) user_lines.append( - For binary ops like residual add, use inputs: [[...],[...]]. ) user_lines.append( - Prefer concrete integers from get_inputs() shapes in the problem; otherwise use symbols like B, H, W. )最后是输入代码示例user_lines.append(PROBLEM_FILE:\npython) user_lines.append(problem_code) user_lines.append() user_lines.append() user_lines.append(FUSED_CODE:) user_lines.append(python) user_lines.append(fused_code) user_lines.append() user_lines.append() user_lines.append( Now return only one fenced JSON block containing the array. No prose. )2.3 使用时机在 extract_subgraphs_to_json 函数中会调用 prompt# Ask LLM for shapes JSON system, user _build_llm_prompt_for_shapes(fused_code, problem_code) Temporary MUX to support Relay while we migrate to OpenAI Responses API. Uses EventAdapter for OpenAI, otherwise Provider inferface provider get_model_provider(model_name) if provider.name ! openai: # 直接调用提供商 result provider.get_response(...) else: # 通过 EventAdapter 流式处理 adapter EventAdapter(...) result adapter.stream(...)0x03 实现subgraph_extractor.py 实现了 KernelFalcon 的 “子图识别” 核心能力 —— 通过 Fuser 生成融合代码→LLM 解析代码提取子图→签名去重合并→输出标准化 JSON为后续 Triton 算子自动生成提供精准的子图粒度输入。3.1 特色LLM 驱动的智能子图识别放弃传统的 “静态代码解析 规则匹配”改用 LLM 理解 PyTorch 代码语义精准识别卷积 / 池化 / 线性层等算子的子图边界、形状、权重特征适配复杂的融合代码场景鲁棒的签名去重机制基于 “算子 输入 / 输出形状 权重结构 数据布局 数据类型” 构建稳定签名避免因命名 / 格式差异导致的重复子图保证子图识别的唯一性全链路容错设计针对 LLM 输出格式异常、JSON 解析失败、代码文件缺失等场景均有明确的容错逻辑和诊断文件输出提升工业级可用性标准化输出格式定义统一的子图 JSON Schema包含 id、类型、形状、权重、计数等核心字段为后续算子生成和模型优化提供标准化输入适配多 LLM 提供商兼容 OpenAI Responses API 和其他 LLM 提供商的接口通过适配层统一调用逻辑保证灵