// Patch context: src/gpu.cpp (index 00a711d09..7d5112bd2). In the patch this helper is
// inserted right after a VulkanDevice::compile_shader_module() definition and directly
// before inject_local_size_xyz().

// Rewrite a SPIR-V module so that its GLCompute entry point declares relaxed fast-math
// defaults (AllowContract | AllowReassoc | AllowTransform) for 32-bit floats through the
// FPFastMathDefault execution mode of SPV_KHR_float_controls2.
//
//   code    - SPIR-V words of the input module
//   size    - size of the input in BYTES (a trailing partial word is ignored)
//   dstcode - receives the output words; any previous content is discarded
//
// The module is copied through unchanged when it cannot be patched safely: the magic number
// does not match, the instruction stream is malformed before the first OpFunction, the id
// bound cannot grow, or one of the required anchors is missing (OpMemoryModel, a GLCompute
// OpEntryPoint, a 32-bit unsigned OpTypeInt, a 32-bit OpTypeFloat, an OpFunction).
//
// Everything that is injected is placed according to the SPIR-V logical layout:
//   - OpCapability FloatControls2              at the end of the OpCapability run
//   - OpExtension "SPV_KHR_float_controls2"    at the end of the OpExtension run
//   - OpExecutionModeId ... FPFastMathDefault  right after the OpEntryPoint run
//   - OpConstant (the fast-math flag mask)     right before the first OpFunction
//
// NOTE(review): OpExecutionModeId is documented as missing before SPIR-V 1.2; this function
// does not inspect or change the version word of the header - confirm what the callers emit.
// NOTE(review): the extension name is compared and written with memcmp/memcpy on the word
// buffer, which presumably expects a little-endian host - confirm for other targets.
static void inject_fast_math(const uint32_t* code, size_t size, std::vector<uint32_t>& dstcode)
{
    // opcodes
    const uint32_t op_capability = 2;
    const uint32_t op_extension = 10;
    const uint32_t op_memory_model = 14;
    const uint32_t op_entry_point = 15;
    const uint32_t op_type_int = 21;
    const uint32_t op_type_float = 22;
    const uint32_t op_constant = 43;
    const uint32_t op_function = 54;
    const uint32_t op_execution_mode_id = 331;

    // enumerants
    const uint32_t execution_model_glcompute = 5;
    const uint32_t execution_mode_fp_fast_math_default = 6028;
    const uint32_t capability_float_controls2 = 6029;
    const uint32_t fast_math_flags = 0x40000 /* AllowTransform */ | 0x20000 /* AllowReassoc */ | 0x10000 /* AllowContract */;

    const uint32_t spirv_magic = 0x07230203;
    const size_t header_words = 5;

    // the extension name as a nul-terminated, zero-padded SPIR-V literal string
    static const char ext_name[] = "SPV_KHR_float_controls2";
    const size_t ext_word_count = (sizeof(ext_name) + 3) / 4;
    uint32_t ext_words[(sizeof(ext_name) + 3) / 4] = {0};
    memcpy(ext_words, ext_name, sizeof(ext_name));

    const size_t total_words = size / sizeof(uint32_t);

    if (!code)
    {
        dstcode.clear();
        return;
    }

    // basic validation; a bound of 0xffffffff leaves no room for the new constant id
    if (total_words < header_words || code[0] != spirv_magic || code[3] == 0xffffffffu)
    {
        dstcode.assign(code, code + total_words);
        return;
    }

    // =========================================================================
    // Pass 1: walk the module up to the first OpFunction, collecting the ids and
    //         the insertion points needed by pass 2
    // =========================================================================
    const uint32_t bound = code[3];
    uint32_t entry_point_id = 0;
    uint32_t float32_type_id = 0;
    uint32_t uint32_type_id = 0;
    bool has_float_controls2_capability = false;
    bool has_float_controls2_extension = false;
    bool seen_memory_model = false;

    const uint32_t* capability_end = nullptr;   // first instruction that is not OpCapability
    const uint32_t* extension_end = nullptr;    // first instruction that is neither OpCapability nor OpExtension
    const uint32_t* entry_points_end = nullptr; // first non-OpEntryPoint instruction after OpMemoryModel
    const uint32_t* first_function = nullptr;

    const uint32_t* const end = code + total_words;

    for (const uint32_t* p = code + header_words; p < end;)
    {
        const uint32_t wordcount = p[0] >> 16;
        const uint32_t op = p[0] & 0xffff;

        // malformed stream, stop scanning; the anchor check below then rejects the module
        if (wordcount == 0 || wordcount > (size_t)(end - p))
            break;

        if (!capability_end && op != op_capability)
            capability_end = p;
        if (!extension_end && op != op_capability && op != op_extension)
            extension_end = p;
        if (seen_memory_model && !entry_points_end && op != op_entry_point)
            entry_points_end = p;

        // operands are only read when the instruction is long enough to hold them
        if (op == op_capability)
        {
            if (wordcount >= 2 && p[1] == capability_float_controls2)
                has_float_controls2_capability = true;
        }
        else if (op == op_extension)
        {
            // bounded comparison, the terminating nul is part of the compared bytes
            if (wordcount > ext_word_count && memcmp(p + 1, ext_words, sizeof(ext_name)) == 0)
                has_float_controls2_extension = true;
        }
        else if (op == op_memory_model)
        {
            seen_memory_model = true;
        }
        else if (op == op_entry_point)
        {
            if (wordcount >= 3 && p[1] == execution_model_glcompute)
                entry_point_id = p[2];
        }
        else if (op == op_type_int)
        {
            if (wordcount == 4 && p[2] == 32 && p[3] == 0)
                uint32_type_id = p[1];
        }
        else if (op == op_type_float)
        {
            if (wordcount == 3 && p[2] == 32)
                float32_type_id = p[1];
        }
        else if (op == op_function)
        {
            // nothing behind the first function is needed as an anchor
            first_function = p;
            break;
        }

        p += wordcount;
    }

    // without every anchor the module cannot be modified safely, hand back the original.
    // when all of them are set the insertion points are ordered
    //   capability_end <= extension_end < entry_points_end <= first_function
    if (entry_point_id == 0 || float32_type_id == 0 || uint32_type_id == 0
            || !capability_end || !extension_end || !entry_points_end || !first_function)
    {
        dstcode.assign(code, code + total_words);
        return;
    }

    // =========================================================================
    // Pass 2: stitch the output together from the original ranges and the new
    //         instructions
    // =========================================================================
    const uint32_t fast_math_constant_id = bound;

    dstcode.clear();
    dstcode.reserve(total_words + 2 + (1 + ext_word_count) + 5 + 4);

    // header with the bumped id bound, followed by the existing capabilities
    dstcode.insert(dstcode.end(), code, capability_end);
    dstcode[3] = bound + 1;

    if (!has_float_controls2_capability)
    {
        dstcode.push_back((2u << 16) | op_capability);
        dstcode.push_back(capability_float_controls2);
    }

    // existing extensions
    dstcode.insert(dstcode.end(), capability_end, extension_end);

    if (!has_float_controls2_extension)
    {
        dstcode.push_back((uint32_t)((ext_word_count + 1) << 16) | op_extension);
        dstcode.insert(dstcode.end(), ext_words, ext_words + ext_word_count);
    }

    // ext inst imports, memory model and every entry point
    dstcode.insert(dstcode.end(), extension_end, entry_points_end);

    // FPFastMathDefault takes <id> operands, so it is declared with OpExecutionModeId
    dstcode.push_back((5u << 16) | op_execution_mode_id);
    dstcode.push_back(entry_point_id);
    dstcode.push_back(execution_mode_fp_fast_math_default);
    dstcode.push_back(float32_type_id);
    dstcode.push_back(fast_math_constant_id);

    // remaining execution modes, debug info, annotations, types and constants
    dstcode.insert(dstcode.end(), entry_points_end, first_function);

    // the flag mask as a 32-bit unsigned constant: result type, result id, value
    dstcode.push_back((4u << 16) | op_constant);
    dstcode.push_back(uint32_type_id);
    dstcode.push_back(fast_math_constant_id);
    dstcode.push_back(fast_math_flags);

    // function bodies are copied verbatim without being parsed again
    dstcode.insert(dstcode.end(), first_function, end);
}

// NOTE(review): the same patch also adds the lines below to a compile_shader_module()
// overload, right after its inject_local_size_xyz() call. That function is only partly
// visible in the patch, so the hunk is kept here as a comment instead of being rewritten.
//
//     std::vector<uint32_t> buffer;
//     inject_fast_math(spv_data_modified, spv_data_size_modified, buffer);
//
//     // export to file
//     FILE* fp = fopen("shader.spv", "wb");
//     if (fp)
//     {
//         fwrite(buffer.data(), sizeof(uint32_t), buffer.size(), fp);
//         fclose(fp);
//     }
//     FILE* fp2 = fopen("shader_modified.spv", "wb");
//     if (fp2)
//     {
//         fwrite(spv_data_modified, sizeof(uint32_t), spv_data_size_modified / sizeof(uint32_t), fp2);
//         fclose(fp2);
//     }
//
//     VkShaderModule shader_module = compile_shader_module(spv_data_modified, spv_data_size_modified);
//
// Two things to resolve there before merging:
//   - the shader module is still created from spv_data_modified, so the output of
//     inject_fast_math() only ever reaches shader.spv and never the driver
//   - both dumps are written unconditionally, with fixed names, on every call