| @@ -3524,6 +3524,135 @@ VkShaderModule VulkanDevice::compile_shader_module(const uint32_t* spv_data, siz | |||||
| return shader_module; | return shader_module; | ||||
| } | } | ||||
// Rewrite a GLCompute SPIR-V module so that 32-bit float operations default to
// the AllowContract | AllowReassoc | AllowTransform fast-math flags, using the
// FPFastMathDefault execution mode from SPV_KHR_float_controls2.
//
//   code     SPIR-V words of the input module
//   size     size of the input module in BYTES
//   dstcode  receives the rewritten module; previous contents are discarded
//
// The following are injected, each in the section the SPIR-V logical layout
// requires:
//   - OpCapability FloatControls2            (end of the capability section, if absent)
//   - OpExtension "SPV_KHR_float_controls2"  (end of the extension section, if absent)
//   - OpExecutionModeId <entry> FPFastMathDefault <f32 type> <flags constant>
//                                            (right after the last OpEntryPoint)
//   - OpConstant <u32 type> <new id> <flags> (right before the first OpFunction)
// The header bound is increased by one for the new constant id and the header
// version is raised to 1.2 when lower, because OpExecutionModeId does not
// exist before SPIR-V 1.2.
//
// When the input is not recognisable SPIR-V, is malformed before the first
// OpFunction, or lacks a GLCompute entry point, a 32-bit float type, an
// unsigned 32-bit int type, OpMemoryModel or OpFunction, the input is copied
// to dstcode unchanged.
static void inject_fast_math(const uint32_t* code, size_t size, std::vector<uint32_t>& dstcode)
{
    // Opcodes and enumerants, values taken from the SPIR-V specification.
    const uint32_t kOpExtension = 10;
    const uint32_t kOpMemoryModel = 14;
    const uint32_t kOpEntryPoint = 15;
    const uint32_t kOpCapability = 17;
    const uint32_t kOpTypeInt = 21;
    const uint32_t kOpTypeFloat = 22;
    const uint32_t kOpConstant = 43;
    const uint32_t kOpFunction = 54;
    const uint32_t kOpExecutionModeId = 331;
    const uint32_t kExecutionModelGLCompute = 5;
    const uint32_t kExecutionModeFPFastMathDefault = 6028;
    const uint32_t kCapabilityFloatControls2 = 6029;
    const uint32_t kVersion_1_2 = 0x00010200;
    const uint32_t kHeaderWords = 5;

    const size_t total_words = size / sizeof(uint32_t);

    // basic validation: a full header and the SPIR-V magic number
    if (size < kHeaderWords * sizeof(uint32_t) || code[0] != 0x07230203)
    {
        dstcode.assign(code, code + total_words);
        return;
    }

    // Extension name encoded as a SPIR-V literal string: nul-terminated and
    // zero-padded to a whole number of words.
    static const char ext_name[] = "SPV_KHR_float_controls2";
    const size_t ext_word_count = (sizeof(ext_name) + 3) / 4;
    uint32_t ext_words[(sizeof(ext_name) + 3) / 4] = {0};
    memcpy(ext_words, ext_name, sizeof(ext_name));

    // =========================================================================
    // Pass 1: scan everything before the first OpFunction, collecting the ids
    //         we need and the insertion points for the new instructions
    // =========================================================================
    const uint32_t bound = code[3];
    uint32_t entry_point_id = 0;
    uint32_t float32_type_id = 0;
    uint32_t uint32_type_id = 0;
    bool has_memory_model = false;
    bool has_float_controls2_capability = false;
    bool has_float_controls2_extension = false;

    const uint32_t* capability_insert_ptr = nullptr; // first instruction that is not OpCapability
    const uint32_t* extension_insert_ptr = nullptr;  // first instruction that is neither OpCapability nor OpExtension
    const uint32_t* exec_mode_insert_ptr = nullptr;  // instruction following the last OpEntryPoint
    const uint32_t* first_function_ptr = nullptr;

    const uint32_t* const end = code + total_words;
    const uint32_t* p = code + kHeaderWords;
    while (p < end)
    {
        const uint32_t wordcount = p[0] >> 16;
        const uint32_t op = p[0] & 0xffff;

        // malformed instruction stream: stop, the missing anchors below make us bail out
        if (wordcount == 0 || wordcount > (size_t)(end - p)) break;

        if (op == kOpFunction)
        {
            // nothing past the first function needs to be inspected
            first_function_ptr = p;
            break;
        }

        if (!capability_insert_ptr && op != kOpCapability) capability_insert_ptr = p;
        if (!extension_insert_ptr && op != kOpCapability && op != kOpExtension) extension_insert_ptr = p;

        switch (op)
        {
        case kOpCapability:
            if (wordcount >= 2 && p[1] == kCapabilityFloatControls2) has_float_controls2_capability = true;
            break;
        case kOpExtension:
            // bounded comparison: never read past the instruction's own words
            if (wordcount == ext_word_count + 1 && memcmp(p + 1, ext_words, sizeof(ext_words)) == 0) has_float_controls2_extension = true;
            break;
        case kOpMemoryModel:
            has_memory_model = true;
            break;
        case kOpEntryPoint:
            if (wordcount >= 3 && p[1] == kExecutionModelGLCompute) entry_point_id = p[2];
            // execution modes must follow ALL entry points
            exec_mode_insert_ptr = p + wordcount;
            break;
        case kOpTypeInt:
            if (wordcount == 4 && p[2] == 32 && p[3] == 0) uint32_type_id = p[1];
            break;
        case kOpTypeFloat:
            if (wordcount == 3 && p[2] == 32) float32_type_id = p[1];
            break;
        }

        p += wordcount;
    }

    // without every piece of key information the module cannot be patched safely
    if (entry_point_id == 0 || float32_type_id == 0 || uint32_type_id == 0 || !has_memory_model || !first_function_ptr)
    {
        dstcode.assign(code, code + total_words);
        return;
    }

    // =========================================================================
    // Pass 2: rebuild the module segment by segment around the insertion points
    //         (capability <= extension < execution mode <= first function)
    // =========================================================================
    const uint32_t fast_math_constant_id = bound;
    const uint32_t new_bound = bound + 1;
    const uint32_t fast_math_flags = 0x40000 /* AllowTransform */ | 0x20000 /* AllowReassoc */ | 0x10000 /* AllowContract */;

    dstcode.clear();
    dstcode.reserve(total_words + 20);

    // header, with the updated id bound and a version that allows OpExecutionModeId
    dstcode.insert(dstcode.end(), code, code + kHeaderWords);
    if (dstcode[1] < kVersion_1_2) dstcode[1] = kVersion_1_2;
    dstcode[3] = new_bound;

    // existing capabilities, then ours
    dstcode.insert(dstcode.end(), code + kHeaderWords, capability_insert_ptr);
    if (!has_float_controls2_capability)
    {
        dstcode.push_back((2u << 16) | kOpCapability);
        dstcode.push_back(kCapabilityFloatControls2);
    }

    // existing extensions, then ours
    dstcode.insert(dstcode.end(), capability_insert_ptr, extension_insert_ptr);
    if (!has_float_controls2_extension)
    {
        dstcode.push_back(((uint32_t)(ext_word_count + 1) << 16) | kOpExtension);
        dstcode.insert(dstcode.end(), ext_words, ext_words + ext_word_count);
    }

    // everything up to and including the last OpEntryPoint, then the execution mode;
    // FPFastMathDefault takes <id> operands, so OpExecutionModeId is required
    dstcode.insert(dstcode.end(), extension_insert_ptr, exec_mode_insert_ptr);
    dstcode.push_back((5u << 16) | kOpExecutionModeId);
    dstcode.push_back(entry_point_id);
    dstcode.push_back(kExecutionModeFPFastMathDefault);
    dstcode.push_back(float32_type_id);
    dstcode.push_back(fast_math_constant_id);

    // remaining declarations, then the flags constant just before the first function
    dstcode.insert(dstcode.end(), exec_mode_insert_ptr, first_function_ptr);
    dstcode.push_back((4u << 16) | kOpConstant);
    dstcode.push_back(uint32_type_id);
    dstcode.push_back(fast_math_constant_id);
    dstcode.push_back(fast_math_flags);

    // function bodies are copied verbatim
    dstcode.insert(dstcode.end(), first_function_ptr, end);
}
| static void inject_local_size_xyz(const uint32_t* code, size_t size, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z, uint32_t* dstcode, size_t* dstsize) | static void inject_local_size_xyz(const uint32_t* code, size_t size, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z, uint32_t* dstcode, size_t* dstsize) | ||||
| { | { | ||||
| uint32_t local_size_x_id = -1; | uint32_t local_size_x_id = -1; | ||||
| @@ -3627,6 +3756,23 @@ VkShaderModule VulkanDevice::compile_shader_module(const uint32_t* spv_data, siz | |||||
| size_t spv_data_size_modified = spv_data_size; | size_t spv_data_size_modified = spv_data_size; | ||||
| inject_local_size_xyz(spv_data, spv_data_size, local_size_x, local_size_y, local_size_z, spv_data_modified, &spv_data_size_modified); | inject_local_size_xyz(spv_data, spv_data_size, local_size_x, local_size_y, local_size_z, spv_data_modified, &spv_data_size_modified); | ||||
| std::vector<uint32_t> buffer; | |||||
| inject_fast_math(spv_data_modified, spv_data_size_modified,buffer); | |||||
| // export to file | |||||
| FILE* fp = fopen("shader.spv", "wb"); | |||||
| if (fp) | |||||
| { | |||||
| fwrite(buffer.data(), sizeof(uint32_t), buffer.size(), fp); | |||||
| fclose(fp); | |||||
| } | |||||
| FILE* fp2 = fopen("shader_modified.spv", "wb"); | |||||
| if (fp2) | |||||
| { | |||||
| fwrite(spv_data_modified, sizeof(uint32_t), spv_data_size_modified / sizeof(uint32_t), fp2); | |||||
| fclose(fp2); | |||||
| } | |||||
| VkShaderModule shader_module = compile_shader_module(spv_data_modified, spv_data_size_modified); | VkShaderModule shader_module = compile_shader_module(spv_data_modified, spv_data_size_modified); | ||||
| free(spv_data_modified); | free(spv_data_modified); | ||||