| @@ -3524,6 +3524,135 @@ VkShaderModule VulkanDevice::compile_shader_module(const uint32_t* spv_data, siz | |||
| return shader_module; | |||
| } | |||
| static void inject_fast_math(const uint32_t* code, size_t size, std::vector<uint32_t>& dstcode) | |||
| { | |||
| // 基本验证 | |||
| if (size < 20 || code[0] != 0x07230203) | |||
| { | |||
| dstcode.assign(code, code + size / sizeof(uint32_t)); | |||
| return; | |||
| } | |||
| // ========================================================================= | |||
| // Pass 1: 分析 SPIR-V,收集所有必需的 ID 和锚点指针 | |||
| // ========================================================================= | |||
| uint32_t bound = code[3]; | |||
| uint32_t entry_point_id = 0; | |||
| uint32_t float32_type_id = 0; | |||
| uint32_t uint32_type_id = 0; | |||
| bool has_float_controls2_capability = false; | |||
| bool has_float_controls2_extension = false; | |||
| const uint32_t* memory_model_ptr = nullptr; | |||
| const uint32_t* first_function_ptr = nullptr; | |||
| const uint32_t* p = code + 5; | |||
| const uint32_t* end = code + (size / sizeof(uint32_t)); | |||
| while (p < end) | |||
| { | |||
| uint16_t wordcount = p[0] >> 16; | |||
| if (wordcount == 0 || p + wordcount > end) break; // 安全检查 | |||
| uint16_t op = p[0] & 0xffff; | |||
| switch (op) { | |||
| case 14: // OpMemoryModel | |||
| if (!memory_model_ptr) memory_model_ptr = p; | |||
| break; | |||
| case 15: // OpEntryPoint | |||
| if (p[1] == 5 /* GLCompute */) entry_point_id = p[2]; | |||
| break; | |||
| case 21: // OpTypeInt | |||
| if (wordcount == 4 && p[2] == 32 && p[3] == 0) uint32_type_id = p[1]; | |||
| break; | |||
| case 22: // OpTypeFloat | |||
| if (wordcount == 3 && p[2] == 32) float32_type_id = p[1]; | |||
| break; | |||
| case 54: // OpFunction | |||
| if (!first_function_ptr) first_function_ptr = p; | |||
| break; | |||
| case 2: // OpCapability | |||
| if (p[1] == 6029 /* FloatControls2 */) has_float_controls2_capability = true; | |||
| break; | |||
| case 10: // OpExtension | |||
| if (strcmp((const char*)&p[1], "SPV_KHR_float_controls2") == 0) has_float_controls2_extension = true; | |||
| break; | |||
| } | |||
| // 如果找到了第一个函数,后面的内容无需再扫描以寻找锚点 | |||
| if (first_function_ptr) break; | |||
| p += wordcount; | |||
| } | |||
| // 如果缺少任何关键信息,则无法安全修改,返回原始代码 | |||
| if (entry_point_id == 0 || float32_type_id == 0 || uint32_type_id == 0 || !memory_model_ptr || !first_function_ptr) | |||
| { | |||
| dstcode.assign(code, code + size / sizeof(uint32_t)); | |||
| return; | |||
| } | |||
| // ========================================================================= | |||
| // Pass 2: 使用锚点构建新的 SPIR-V | |||
| // ========================================================================= | |||
| dstcode.clear(); | |||
| dstcode.reserve(size / sizeof(uint32_t) + 20); | |||
| // -- 准备新ID和数据 -- | |||
| uint32_t fast_math_constant_id = bound; | |||
| uint32_t new_bound = bound + 1; | |||
| const uint32_t fast_math_flags = 0x40000 /* AllowTransform */ | 0x20000 /* AllowReassoc */ | 0x10000 /* AllowContract */; | |||
| // -- 写入新的头部,并更新Bound -- | |||
| dstcode.insert(dstcode.end(), code, code + 5); | |||
| dstcode[3] = new_bound; | |||
| p = code + 5; | |||
| while (p < end) | |||
| { | |||
| uint16_t wordcount = p[0] >> 16; | |||
| if (wordcount == 0) break; | |||
| // 在复制第一条 OpFunction 指令之前,注入 OpConstant | |||
| if (p == first_function_ptr) { | |||
| dstcode.push_back((4u << 16) | 43 /* OpConstant */); | |||
| dstcode.push_back(uint32_type_id); | |||
| dstcode.push_back(fast_math_constant_id); | |||
| dstcode.push_back(fast_math_flags); | |||
| } | |||
| // 复制当前指令 | |||
| dstcode.insert(dstcode.end(), p, p + wordcount); | |||
| // 在复制了锚点指令之后,注入新指令 | |||
| if (p == memory_model_ptr) | |||
| { | |||
| if (!has_float_controls2_capability) { | |||
| dstcode.push_back((2u << 16) | 2 /* OpCapability */); | |||
| dstcode.push_back(6029 /* FloatControls2 */); | |||
| } | |||
| if (!has_float_controls2_extension) { | |||
| const char ext_name[] = "SPV_KHR_float_controls2"; | |||
| size_t ext_word_count = (sizeof(ext_name) + 3) / 4; | |||
| dstcode.push_back(((ext_word_count + 1) << 16) | 10 /* OpExtension */); | |||
| std::vector<uint32_t> ext_words(ext_word_count, 0); | |||
| memcpy(ext_words.data(), ext_name, sizeof(ext_name)); | |||
| dstcode.insert(dstcode.end(), ext_words.begin(), ext_words.end()); | |||
| } | |||
| } | |||
| else if ((p[0] & 0xffff) == 15 /* OpEntryPoint */ && p[2] == entry_point_id) | |||
| { | |||
| dstcode.push_back((5u << 16) | 16 /* OpExecutionMode */); | |||
| dstcode.push_back(entry_point_id); | |||
| dstcode.push_back(6028 /* FPFastMathDefault */); | |||
| dstcode.push_back(float32_type_id); | |||
| dstcode.push_back(fast_math_constant_id); | |||
| } | |||
| p += wordcount; | |||
| } | |||
| } | |||
| static void inject_local_size_xyz(const uint32_t* code, size_t size, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z, uint32_t* dstcode, size_t* dstsize) | |||
| { | |||
| uint32_t local_size_x_id = -1; | |||
| @@ -3627,6 +3756,23 @@ VkShaderModule VulkanDevice::compile_shader_module(const uint32_t* spv_data, siz | |||
| size_t spv_data_size_modified = spv_data_size; | |||
| inject_local_size_xyz(spv_data, spv_data_size, local_size_x, local_size_y, local_size_z, spv_data_modified, &spv_data_size_modified); | |||
| std::vector<uint32_t> buffer; | |||
| inject_fast_math(spv_data_modified, spv_data_size_modified,buffer); | |||
| // export to file | |||
| FILE* fp = fopen("shader.spv", "wb"); | |||
| if (fp) | |||
| { | |||
| fwrite(buffer.data(), sizeof(uint32_t), buffer.size(), fp); | |||
| fclose(fp); | |||
| } | |||
| FILE* fp2 = fopen("shader_modified.spv", "wb"); | |||
| if (fp2) | |||
| { | |||
| fwrite(spv_data_modified, sizeof(uint32_t), spv_data_size_modified / sizeof(uint32_t), fp2); | |||
| fclose(fp2); | |||
| } | |||
| VkShaderModule shader_module = compile_shader_module(spv_data_modified, spv_data_size_modified); | |||
| free(spv_data_modified); | |||