Browse Source

feat: init for fast_math inject

pull/6223/head
ice 11 months ago
parent
commit
17669e5998
1 changed files with 146 additions and 0 deletions
  1. +146
    -0
      src/gpu.cpp

+ 146
- 0
src/gpu.cpp View File

@@ -3524,6 +3524,135 @@ VkShaderModule VulkanDevice::compile_shader_module(const uint32_t* spv_data, siz
return shader_module;
}

// Rewrite a SPIR-V module so that float32 operations of its GLCompute entry
// point default to the AllowContract | AllowReassoc | AllowTransform fast-math
// flags, using SPV_KHR_float_controls2.
//
// code     pointer to the SPIR-V words of the source module
// size     size of the source module in BYTES (not words)
// dstcode  receives the rewritten module; its previous content is discarded
//
// The rewritten module gains, in the positions required by the SPIR-V logical
// layout:
//   - OpCapability FloatControls2            (end of the capability section, if missing)
//   - OpExtension "SPV_KHR_float_controls2"  (end of the extension section, if missing)
//   - OpExecutionModeId <entry> FPFastMathDefault <float32 type> <flags constant>
//                                            (right after the last OpEntryPoint)
//   - OpConstant <uint32 type> <new id> <flags>  (right before the first OpFunction)
// and its id bound is raised by one to make room for the new constant id.
//
// If the input is not a SPIR-V module, is malformed before its first
// OpFunction, or lacks a GLCompute entry point, an OpMemoryModel, a 32-bit
// float type, a 32-bit unsigned int type or any function, dstcode receives an
// unmodified copy of the input instead.
//
// NOTE(review): OpExecutionModeId is only available from SPIR-V 1.2 on; the
// module version in code[1] is not checked here - confirm callers always pass
// modules of a suitable version.
static void inject_fast_math(const uint32_t* code, size_t size, std::vector<uint32_t>& dstcode)
{
    // SPIR-V opcodes and enumerants used below
    const uint32_t OP_CAPABILITY = 2;
    const uint32_t OP_EXTENSION = 10;
    const uint32_t OP_MEMORY_MODEL = 14;
    const uint32_t OP_ENTRY_POINT = 15;
    const uint32_t OP_TYPE_INT = 21;
    const uint32_t OP_TYPE_FLOAT = 22;
    const uint32_t OP_CONSTANT = 43;
    const uint32_t OP_FUNCTION = 54;
    const uint32_t OP_EXECUTION_MODE_ID = 331;
    const uint32_t EXECUTION_MODEL_GLCOMPUTE = 5;
    const uint32_t CAPABILITY_FLOAT_CONTROLS2 = 6029;
    const uint32_t EXECUTION_MODE_FP_FAST_MATH_DEFAULT = 6028;

    // The literal string is stored nul-terminated and padded to whole words.
    // Like the rest of this file's SPIR-V handling it is compared and written
    // in host byte order, which matches SPIR-V packing on little-endian hosts.
    static const char ext_name[] = "SPV_KHR_float_controls2";
    const size_t ext_word_count = (sizeof(ext_name) + 3) / 4;

    const size_t word_total = size / sizeof(uint32_t);

    if (!code)
    {
        dstcode.clear();
        return;
    }

    // basic validation: 5-word header and the SPIR-V magic number
    if (word_total < 5 || code[0] != 0x07230203)
    {
        dstcode.assign(code, code + word_total);
        return;
    }

    // =========================================================================
    // Pass 1: walk the declarations in front of the first function, validate
    // them and collect the ids and insertion anchors
    // =========================================================================
    const uint32_t bound = code[3];
    uint32_t entry_point_id = 0;
    uint32_t float32_type_id = 0;
    uint32_t uint32_type_id = 0;
    bool has_memory_model = false;
    bool has_float_controls2_capability = false;
    bool has_float_controls2_extension = false;

    const uint32_t* capability_end_ptr = 0;   // first instruction that is not OpCapability
    const uint32_t* extension_end_ptr = 0;    // first instruction that is neither OpCapability nor OpExtension
    const uint32_t* last_entry_point_ptr = 0; // execution modes must follow ALL entry points
    const uint32_t* first_function_ptr = 0;

    const uint32_t* const end = code + word_total;
    const uint32_t* p = code + 5;

    while (p < end)
    {
        const uint32_t wordcount = p[0] >> 16;
        const uint32_t op = p[0] & 0xffff;

        // malformed stream: stop without a function anchor so that the
        // original module is returned untouched below
        if (wordcount == 0 || wordcount > (size_t)(end - p)) break;

        if (!capability_end_ptr && op != OP_CAPABILITY) capability_end_ptr = p;
        if (!extension_end_ptr && op != OP_CAPABILITY && op != OP_EXTENSION) extension_end_ptr = p;

        if (op == OP_FUNCTION)
        {
            // nothing behind the first function is needed as an anchor
            first_function_ptr = p;
            break;
        }

        switch (op)
        {
        case OP_CAPABILITY:
            if (wordcount >= 2 && p[1] == CAPABILITY_FLOAT_CONTROLS2) has_float_controls2_capability = true;
            break;
        case OP_EXTENSION:
            // bounded comparison, the operand is not trusted to be nul-terminated
            if (wordcount - 1 >= ext_word_count && memcmp(p + 1, ext_name, sizeof(ext_name)) == 0) has_float_controls2_extension = true;
            break;
        case OP_MEMORY_MODEL:
            has_memory_model = true;
            break;
        case OP_ENTRY_POINT:
            last_entry_point_ptr = p;
            if (wordcount >= 4 && p[1] == EXECUTION_MODEL_GLCOMPUTE) entry_point_id = p[2];
            break;
        case OP_TYPE_INT:
            if (wordcount == 4 && p[2] == 32 && p[3] == 0) uint32_type_id = p[1];
            break;
        case OP_TYPE_FLOAT:
            if (wordcount == 3 && p[2] == 32) float32_type_id = p[1];
            break;
        default:
            break;
        }

        p += wordcount;
    }

    // without every piece of key information the module cannot be modified
    // safely, hand back the original code
    if (entry_point_id == 0 || float32_type_id == 0 || uint32_type_id == 0 || !has_memory_model || !first_function_ptr)
    {
        dstcode.assign(code, code + word_total);
        return;
    }

    // =========================================================================
    // Pass 2: rebuild the module, injecting the new instructions at the anchors
    // =========================================================================
    const uint32_t fast_math_constant_id = bound;
    const uint32_t fast_math_flags = 0x40000 /* AllowTransform */ | 0x20000 /* AllowReassoc */ | 0x10000 /* AllowContract */;

    dstcode.clear();
    dstcode.reserve(word_total + 20);

    // header with the bound raised for the new constant id
    dstcode.insert(dstcode.end(), code, code + 5);
    dstcode[3] = bound + 1;

    // Every instruction in [code + 5, first_function_ptr) was validated by
    // pass 1, so stepping by the word count is safe and ends exactly on
    // first_function_ptr.
    for (p = code + 5; p != first_function_ptr; p += (p[0] >> 16))
    {
        const uint32_t wordcount = p[0] >> 16;

        if (p == capability_end_ptr && !has_float_controls2_capability)
        {
            dstcode.push_back((2u << 16) | OP_CAPABILITY);
            dstcode.push_back(CAPABILITY_FLOAT_CONTROLS2);
        }

        if (p == extension_end_ptr && !has_float_controls2_extension)
        {
            dstcode.push_back(((uint32_t)(ext_word_count + 1) << 16) | OP_EXTENSION);
            const size_t name_offset = dstcode.size();
            dstcode.resize(name_offset + ext_word_count, 0);
            memcpy(&dstcode[name_offset], ext_name, sizeof(ext_name));
        }

        // copy the current instruction
        dstcode.insert(dstcode.end(), p, p + wordcount);

        if (p == last_entry_point_ptr)
        {
            // FPFastMathDefault takes <id> operands, which requires
            // OpExecutionModeId rather than OpExecutionMode; the constant it
            // refers to is declared further down (forward reference)
            dstcode.push_back((5u << 16) | OP_EXECUTION_MODE_ID);
            dstcode.push_back(entry_point_id);
            dstcode.push_back(EXECUTION_MODE_FP_FAST_MATH_DEFAULT);
            dstcode.push_back(float32_type_id);
            dstcode.push_back(fast_math_constant_id);
        }
    }

    // the flags constant closes the declaration section, right in front of the
    // first OpFunction
    dstcode.push_back((4u << 16) | OP_CONSTANT);
    dstcode.push_back(uint32_type_id);
    dstcode.push_back(fast_math_constant_id);
    dstcode.push_back(fast_math_flags);

    // function bodies are copied verbatim in one go, no re-parsing needed
    dstcode.insert(dstcode.end(), first_function_ptr, end);
}

static void inject_local_size_xyz(const uint32_t* code, size_t size, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z, uint32_t* dstcode, size_t* dstsize)
{
uint32_t local_size_x_id = -1;
@@ -3627,6 +3756,23 @@ VkShaderModule VulkanDevice::compile_shader_module(const uint32_t* spv_data, siz
size_t spv_data_size_modified = spv_data_size;
inject_local_size_xyz(spv_data, spv_data_size, local_size_x, local_size_y, local_size_z, spv_data_modified, &spv_data_size_modified);

std::vector<uint32_t> buffer;
inject_fast_math(spv_data_modified, spv_data_size_modified,buffer);

// export to file
FILE* fp = fopen("shader.spv", "wb");
if (fp)
{
fwrite(buffer.data(), sizeof(uint32_t), buffer.size(), fp);
fclose(fp);
}
FILE* fp2 = fopen("shader_modified.spv", "wb");
if (fp2)
{
fwrite(spv_data_modified, sizeof(uint32_t), spv_data_size_modified / sizeof(uint32_t), fp2);
fclose(fp2);
}

VkShaderModule shader_module = compile_shader_module(spv_data_modified, spv_data_size_modified);

free(spv_data_modified);


Loading…
Cancel
Save