From e14cc272ac8e4f2251f2e01d2dfe6844c4f3db32 Mon Sep 17 00:00:00 2001 From: FhqTreap <45459183+FhqTreap@users.noreply.github.com> Date: Wed, 6 Sep 2023 14:47:20 +0800 Subject: [PATCH] gelu vk op tanh fix (#5008) --- src/layer/vulkan/shader/gelu.comp | 5 +++++ src/layer/vulkan/shader/gelu_pack4.comp | 5 +++++ src/layer/vulkan/shader/gelu_pack8.comp | 6 ++++++ 3 files changed, 16 insertions(+) diff --git a/src/layer/vulkan/shader/gelu.comp b/src/layer/vulkan/shader/gelu.comp index dfbe5486e..f389101d1 100644 --- a/src/layer/vulkan/shader/gelu.comp +++ b/src/layer/vulkan/shader/gelu.comp @@ -62,7 +62,12 @@ void main() #endif // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3))) + +#if NCNN_moltenvk + v = 0.5f * v * (1.0f + afp(tanh(float(0.79788452f * (v + 0.044715f * v * v * v))))); +#else v = 0.5f * v * (1.0f + tanh(0.79788452f * (v + 0.044715f * v * v * v))); +#endif #if NCNN_image_shader image3d_st1(top_blob_3d, ivec3(gx, gy, gz), v); diff --git a/src/layer/vulkan/shader/gelu_pack4.comp b/src/layer/vulkan/shader/gelu_pack4.comp index 764f04458..3d9ee1bf0 100644 --- a/src/layer/vulkan/shader/gelu_pack4.comp +++ b/src/layer/vulkan/shader/gelu_pack4.comp @@ -62,7 +62,12 @@ void main() #endif // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3))) + +#if NCNN_moltenvk + v = 0.5f * v * (1.0f + afpvec4(tanh(vec4(0.79788452f * (v + 0.044715f * v * v * v))))); +#else v = 0.5f * v * (1.0f + tanh(0.79788452f * (v + 0.044715f * v * v * v))); +#endif #if NCNN_image_shader image3d_st4(top_blob_3d, ivec3(gx, gy, gz), v); diff --git a/src/layer/vulkan/shader/gelu_pack8.comp b/src/layer/vulkan/shader/gelu_pack8.comp index 02ce389c4..47d181147 100644 --- a/src/layer/vulkan/shader/gelu_pack8.comp +++ b/src/layer/vulkan/shader/gelu_pack8.comp @@ -63,8 +63,14 @@ void main() #endif // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3))) + +#if NCNN_moltenvk + v[0] = 0.5f * v[0] * (1.0f + afpvec4(tanh(vec4(0.79788452f * (v[0] + 0.044715f * v[0] * v[0] * v[0]))))); + v[1] = 0.5f * v[1] * (1.0f + afpvec4(tanh(vec4(0.79788452f * (v[1] + 0.044715f * v[1] * v[1] * v[1]))))); +#else v[0] = 0.5f * v[0] * (1.0f + tanh(0.79788452f * (v[0] + 0.044715f * v[0] * v[0] * v[0]))); v[1] = 0.5f * v[1] * (1.0f + tanh(0.79788452f * (v[1] + 0.044715f * v[1] * v[1] * v[1]))); +#endif #if NCNN_image_shader image3d_st8(top_blob_3d, ivec3(gx, gy, gz), v);