Browse Source

fix: missing fast_math_flag in the Android code

pull/6223/head
ice 11 months ago
parent
commit
2e7b13bd57
3 changed files with 11 additions and 13 deletions
  1. +1
    -1
      src/gpu.h
  2. +1
    -1
      src/pipeline.cpp
  3. +9
    -11
      tests/test_fast_math.cpp

+ 1
- 1
src/gpu.h View File

@@ -414,7 +414,7 @@ public:
VkShaderModule compile_shader_module(const uint32_t* spv_data, size_t spv_data_size) const;

// with fixed workgroup size
VkShaderModule compile_shader_module(const uint32_t* spv_data, size_t spv_data_size, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z, uint32_t fast_math_flag) const;
VkShaderModule compile_shader_module(const uint32_t* spv_data, size_t spv_data_size, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z, uint32_t fast_math_flag = 0) const;

// helper for creating pipeline
int create_descriptorset_layout(int binding_count, const int* binding_types, VkDescriptorSetLayout* descriptorset_layout) const;


+ 1
- 1
src/pipeline.cpp View File

@@ -461,7 +461,7 @@ int ImportAndroidHardwareBufferPipeline::create_shader_module(const Option& opt)

set_shader_info(shader_info);

VkShaderModule shader_module = vkdev->compile_shader_module(spv_data, spv_data_size, local_size_x(), local_size_y(), local_size_z());
VkShaderModule shader_module = vkdev->compile_shader_module(spv_data, spv_data_size, local_size_x(), local_size_y(), local_size_z(), opt.fast_math_flag);
set_shader_module(shader_module);

return 0;


+ 9
- 11
tests/test_fast_math.cpp View File

@@ -36,7 +36,7 @@ static int test_vulkan_fast_math()
{
// Define model path based on environment
// Create a random input matrix
ncnn::Mat input = RandomMat(512, 512, 3);
ncnn::Mat input = RandomMat(224, 224, 3);
DataReaderFromEmpty dr;

#ifdef __EMSCRIPTEN__
@@ -58,7 +58,7 @@ static int test_vulkan_fast_math()
net_default.opt.use_fp16_storage = false;
net_default.opt.use_fp16_packed = false;

net_default.load_param(MODEL_DIR "/vision_transformer.param");
net_default.load_param(MODEL_DIR "/resnet50.param");
net_default.load_model(dr);
printf("Default net loaded successfully.\n");

@@ -70,16 +70,14 @@ static int test_vulkan_fast_math()
printf("==================================================\n");
ncnn::Net net_fast_math;
net_fast_math.opt.use_vulkan_compute = true;
net_fast_math.opt.vk_fast_math_flag = ncnn::Option::VK_FAST_MATH_FLAG_Fast
| ncnn::Option::VK_FAST_MATH_FLAG_AllowContract
| ncnn::Option::VK_FAST_MATH_FLAG_AllowReassoc
| ncnn::Option::VK_FAST_MATH_FLAG_AllowTransform;
net_fast_math.opt.vk_fast_math_flag = ncnn::Option::VK_FAST_MATH_FLAG_AllowContract;

net_fast_math.opt.vulkan_device_index = device_index;
net_fast_math.opt.use_fp16_arithmetic = false;
net_fast_math.opt.use_fp16_packed = false;
net_fast_math.opt.use_fp16_storage = false;

net_fast_math.load_param(MODEL_DIR "/vision_transformer.param");
net_fast_math.load_param(MODEL_DIR "/resnet50.param");
net_fast_math.load_model(dr);
printf("Fast math net loaded successfully.\n");

@@ -92,12 +90,12 @@ static int test_vulkan_fast_math()
ncnn::Mat output_default, output_fast_math;
{
ncnn::Extractor ex = net_default.create_extractor();
ex.input("input", input);
ex.input("data", input);
ex.extract("output", output_default);
}
{
ncnn::Extractor ex = net_fast_math.create_extractor();
ex.input("input", input);
ex.input("data", input);
ex.extract("output", output_fast_math);
}
printf("Warm-up complete.\n");
@@ -118,7 +116,7 @@ static int test_vulkan_fast_math()
for (int i = 0; i < loop_count; i++)
{
ncnn::Extractor ex = net_default.create_extractor();
ex.input("input", input);
ex.input("data", input);
ex.extract("output", output_default);
}
double end = ncnn::get_current_time();
@@ -132,7 +130,7 @@ static int test_vulkan_fast_math()
for (int i = 0; i < loop_count; i++)
{
ncnn::Extractor ex = net_fast_math.create_extractor();
ex.input("input", input);
ex.input("data", input);
ex.extract("output", output_fast_math);
}
double end = ncnn::get_current_time();


Loading…
Cancel
Save