| @@ -23,6 +23,7 @@ | |||||
| #include <string.h> | #include <string.h> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <string> | |||||
| #include <vector> | #include <vector> | ||||
| #include "mat.h" | #include "mat.h" | ||||
| @@ -776,46 +777,44 @@ VulkanDevice::VulkanDevice(int device_index) : info(g_gpu_infos[device_index]) | |||||
| enabledExtensions.push_back("VK_KHR_storage_buffer_storage_class"); | enabledExtensions.push_back("VK_KHR_storage_buffer_storage_class"); | ||||
| void* enabledExtensionFeatures = 0; | void* enabledExtensionFeatures = 0; | ||||
| if (support_VK_KHR_get_physical_device_properties2) | |||||
| // enable int8 storage | |||||
| VkPhysicalDevice8BitStorageFeaturesKHR enabled8BitStorageFeatures; | |||||
| enabled8BitStorageFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES_KHR; | |||||
| enabled8BitStorageFeatures.pNext = 0; | |||||
| enabled8BitStorageFeatures.storageBuffer8BitAccess = info.support_int8_storage; | |||||
| enabled8BitStorageFeatures.uniformAndStorageBuffer8BitAccess = VK_FALSE; | |||||
| enabled8BitStorageFeatures.storagePushConstant8 = VK_FALSE; | |||||
| if (support_VK_KHR_get_physical_device_properties2 && info.support_VK_KHR_8bit_storage) | |||||
| { | { | ||||
| // enable int8 storage | |||||
| VkPhysicalDevice8BitStorageFeaturesKHR enabled8BitStorageFeatures; | |||||
| enabled8BitStorageFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES_KHR; | |||||
| enabled8BitStorageFeatures.pNext = 0; | |||||
| enabled8BitStorageFeatures.storageBuffer8BitAccess = info.support_int8_storage; | |||||
| enabled8BitStorageFeatures.uniformAndStorageBuffer8BitAccess = VK_FALSE; | |||||
| enabled8BitStorageFeatures.storagePushConstant8 = VK_FALSE; | |||||
| if (info.support_VK_KHR_8bit_storage) | |||||
| { | |||||
| enabled8BitStorageFeatures.pNext = enabledExtensionFeatures; | |||||
| enabledExtensionFeatures = &enabled8BitStorageFeatures; | |||||
| } | |||||
| enabled8BitStorageFeatures.pNext = enabledExtensionFeatures; | |||||
| enabledExtensionFeatures = &enabled8BitStorageFeatures; | |||||
| } | |||||
| // enable fp16/int16 storage | |||||
| VkPhysicalDevice16BitStorageFeaturesKHR enabled16BitStorageFeatures; | |||||
| enabled16BitStorageFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES_KHR; | |||||
| enabled16BitStorageFeatures.pNext = 0; | |||||
| enabled16BitStorageFeatures.storageBuffer16BitAccess = info.support_fp16_storage; | |||||
| enabled16BitStorageFeatures.uniformAndStorageBuffer16BitAccess = VK_FALSE; | |||||
| enabled16BitStorageFeatures.storagePushConstant16 = VK_FALSE; | |||||
| enabled16BitStorageFeatures.storageInputOutput16 = VK_FALSE; | |||||
| if (info.support_VK_KHR_16bit_storage) | |||||
| { | |||||
| enabled16BitStorageFeatures.pNext = enabledExtensionFeatures; | |||||
| enabledExtensionFeatures = &enabled16BitStorageFeatures; | |||||
| } | |||||
| // enable fp16/int16 storage | |||||
| VkPhysicalDevice16BitStorageFeaturesKHR enabled16BitStorageFeatures; | |||||
| enabled16BitStorageFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES_KHR; | |||||
| enabled16BitStorageFeatures.pNext = 0; | |||||
| enabled16BitStorageFeatures.storageBuffer16BitAccess = info.support_fp16_storage; | |||||
| enabled16BitStorageFeatures.uniformAndStorageBuffer16BitAccess = VK_FALSE; | |||||
| enabled16BitStorageFeatures.storagePushConstant16 = VK_FALSE; | |||||
| enabled16BitStorageFeatures.storageInputOutput16 = VK_FALSE; | |||||
| if (support_VK_KHR_get_physical_device_properties2 && info.support_VK_KHR_16bit_storage) | |||||
| { | |||||
| enabled16BitStorageFeatures.pNext = enabledExtensionFeatures; | |||||
| enabledExtensionFeatures = &enabled16BitStorageFeatures; | |||||
| } | |||||
| // enable fp16/int8 arithmetic | |||||
| VkPhysicalDeviceFloat16Int8FeaturesKHR enabledFloat16Int8Features; | |||||
| enabledFloat16Int8Features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR; | |||||
| enabledFloat16Int8Features.pNext = 0; | |||||
| enabledFloat16Int8Features.shaderFloat16 = info.support_fp16_arithmetic; | |||||
| enabledFloat16Int8Features.shaderInt8 = info.support_int8_arithmetic; | |||||
| if (info.support_VK_KHR_shader_float16_int8) | |||||
| { | |||||
| enabledFloat16Int8Features.pNext = enabledExtensionFeatures; | |||||
| enabledExtensionFeatures = &enabledFloat16Int8Features; | |||||
| } | |||||
| // enable fp16/int8 arithmetic | |||||
| VkPhysicalDeviceFloat16Int8FeaturesKHR enabledFloat16Int8Features; | |||||
| enabledFloat16Int8Features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR; | |||||
| enabledFloat16Int8Features.pNext = 0; | |||||
| enabledFloat16Int8Features.shaderFloat16 = info.support_fp16_arithmetic; | |||||
| enabledFloat16Int8Features.shaderInt8 = info.support_int8_arithmetic; | |||||
| if (support_VK_KHR_get_physical_device_properties2 && info.support_VK_KHR_shader_float16_int8) | |||||
| { | |||||
| enabledFloat16Int8Features.pNext = enabledExtensionFeatures; | |||||
| enabledExtensionFeatures = &enabledFloat16Int8Features; | |||||
| } | } | ||||
| VkDeviceQueueCreateInfo deviceQueueCreateInfos[2]; | VkDeviceQueueCreateInfo deviceQueueCreateInfos[2]; | ||||
| @@ -875,15 +874,28 @@ VulkanDevice::~VulkanDevice() | |||||
| vkDestroyDevice(device, 0); | vkDestroyDevice(device, 0); | ||||
| } | } | ||||
| VkShaderModule VulkanDevice::get_shader_module(const char* name) const | |||||
| VkShaderModule VulkanDevice::get_shader_module(const char* _name) const | |||||
| { | { | ||||
| std::string name = _name; | |||||
| // if (info.support_fp16_arithmetic) | |||||
| // { | |||||
| // name += "_fp16a"; | |||||
| // } | |||||
| // else if (info.support_fp16_storage) | |||||
| // { | |||||
| // name += "_fp16s"; | |||||
| // } | |||||
| // | |||||
| // fprintf(stderr, "get_shader_module %s\n", name.c_str()); | |||||
| for (int i=0; i<layer_shader_registry_entry_count; i++) | for (int i=0; i<layer_shader_registry_entry_count; i++) | ||||
| { | { | ||||
| if (strcmp(layer_shader_registry[i].name, name) == 0) | |||||
| if (strcmp(layer_shader_registry[i].name, name.c_str()) == 0) | |||||
| return shader_modules[i]; | return shader_modules[i]; | ||||
| } | } | ||||
| fprintf(stderr, "no such shader module %s\n", name); | |||||
| fprintf(stderr, "no such shader module %s\n", name.c_str()); | |||||
| return 0; | return 0; | ||||
| } | } | ||||