Browse Source

improve vulkan winograd f43 fp16 numerical stability (#4492)

tags/20230223
nihui GitHub 3 years ago
parent
commit
dfbcd3e69b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 756 additions and 491 deletions
  1. +6
    -5
      src/layer/vulkan/convolution_vulkan.cpp
  2. +148
    -96
      src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_input.comp
  3. +102
    -56
      src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_output.comp
  4. +148
    -96
      src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_input.comp
  5. +102
    -56
      src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_output.comp
  6. +148
    -96
      src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_input.comp
  7. +102
    -56
      src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_output.comp
  8. +0
    -10
      tests/test_convolution.cpp
  9. +0
    -10
      tests/test_convolution_1.cpp
  10. +0
    -10
      tests/test_convolution_2.cpp

+ 6
- 5
src/layer/vulkan/convolution_vulkan.cpp View File

@@ -188,13 +188,14 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
Mat weight_data_tm;
weight_data_tm.create(6 * 6, num_input, num_output);

const float sq2 = 1.41421356237f;
const float ktm[6][3] = {
{1.0f, 0.0f, 0.0f},
{-2.0f / 3, -2.0f / 3, -2.0f / 3},
{-2.0f / 3, 2.0f / 3, -2.0f / 3},
{1.0f / 6, 1.0f / 3, 2.0f / 3},
{1.0f / 6, -1.0f / 3, 2.0f / 3},
{0.0f, 0.0f, 4.0f}
{-2.0f / 3, -sq2 / 3, -1.0f / 3},
{-2.0f / 3, sq2 / 3, -1.0f / 3},
{1.0f / 6, sq2 / 6, 1.0f / 3},
{1.0f / 6, -sq2 / 6, 1.0f / 3},
{0.0f, 0.0f, 1.0f}
};

#pragma omp parallel for num_threads(opt.num_threads)


+ 148
- 96
src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_input.comp View File

@@ -156,106 +156,158 @@ void main()
afp v55 = sy + 5 < psc(h) && sx + 5 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset45.y + 5) : afp(0.f);
#endif

#define sq2 1.41421356237
#define sq2_d2 1.41421356237/2

// const float itm[6][6] = {
// {1.0f, 0.0f, -1.25f, 0.0f, 0.25f, 0.0f},
// {0.0f,-1.0f, -1.0f, 0.25f, 0.25f, 0.0f},
// {0.0f, 1.0f, -1.0f,-0.25f, 0.25f, 0.0f},
// {0.0f,-0.5f, -0.25f, 0.5f, 0.25f, 0.0f},
// {0.0f, 0.5f, -0.25f,-0.5f, 0.25f, 0.0f},
// {0.0f, 1.0f, 0.0f,-1.25f, 0.0f, 0.25f}
// {1.0f, 0.0f, -2.5f, 0.0f, 1.0f, 0.0f},
// {0.0f, -sq2, -2.0f, sq2/2, 1.0f, 0.0f},
// {0.0f, sq2, -2.0f, -sq2/2, 1.0f, 0.0f},
// {0.0f, -sq2/2, -0.5f, sq2, 1.0f, 0.0f},
// {0.0f, sq2/2, -0.5f, -sq2, 1.0f, 0.0f},
// {0.0f, 1.0f, 0.0f, -2.5f, 0.0f, 1.0f}
// };

// 0 = 1 * r00 - 1.25 * r02 + 0.25 * r04
// 1 = -1 * (r01 + r02) + 0.25 * (r04 + r03)
// 2 = 1 * (r01 - r02) + 0.25 * (r04 - r03)
// 3 = -0.5 * (r01 - r03) + 0.25 * (r04 - r02)
// 4 = 0.5 * (r01 - r03) + 0.25 * (r04 - r02)
// 5 = 1 * r01 - 1.25 * r03 + 0.25 * r05

// implicit transpose
afp m00 = v00 - v02 * afp(1.25) + v04 * afp(0.25);
afp m01 = v10 - v12 * afp(1.25) + v14 * afp(0.25);
afp m02 = v20 - v22 * afp(1.25) + v24 * afp(0.25);
afp m03 = v30 - v32 * afp(1.25) + v34 * afp(0.25);
afp m04 = v40 - v42 * afp(1.25) + v44 * afp(0.25);
afp m05 = v50 - v52 * afp(1.25) + v54 * afp(0.25);

afp m10 = (v04 + v03) * afp(0.25) - (v01 + v02);
afp m11 = (v14 + v13) * afp(0.25) - (v11 + v12);
afp m12 = (v24 + v23) * afp(0.25) - (v21 + v22);
afp m13 = (v34 + v33) * afp(0.25) - (v31 + v32);
afp m14 = (v44 + v43) * afp(0.25) - (v41 + v42);
afp m15 = (v54 + v53) * afp(0.25) - (v51 + v52);

afp m20 = (v04 - v03) * afp(0.25) + (v01 - v02);
afp m21 = (v14 - v13) * afp(0.25) + (v11 - v12);
afp m22 = (v24 - v23) * afp(0.25) + (v21 - v22);
afp m23 = (v34 - v33) * afp(0.25) + (v31 - v32);
afp m24 = (v44 - v43) * afp(0.25) + (v41 - v42);
afp m25 = (v54 - v53) * afp(0.25) + (v51 - v52);

afp m30 = (v04 - v02) * afp(0.25) - (v01 - v03) * afp(0.5);
afp m31 = (v14 - v12) * afp(0.25) - (v11 - v13) * afp(0.5);
afp m32 = (v24 - v22) * afp(0.25) - (v21 - v23) * afp(0.5);
afp m33 = (v34 - v32) * afp(0.25) - (v31 - v33) * afp(0.5);
afp m34 = (v44 - v42) * afp(0.25) - (v41 - v43) * afp(0.5);
afp m35 = (v54 - v52) * afp(0.25) - (v51 - v53) * afp(0.5);

afp m40 = (v04 - v02) * afp(0.25) + (v01 - v03) * afp(0.5);
afp m41 = (v14 - v12) * afp(0.25) + (v11 - v13) * afp(0.5);
afp m42 = (v24 - v22) * afp(0.25) + (v21 - v23) * afp(0.5);
afp m43 = (v34 - v32) * afp(0.25) + (v31 - v33) * afp(0.5);
afp m44 = (v44 - v42) * afp(0.25) + (v41 - v43) * afp(0.5);
afp m45 = (v54 - v52) * afp(0.25) + (v51 - v53) * afp(0.5);

afp m50 = v01 - v03 * afp(1.25) + v05 * afp(0.25);
afp m51 = v11 - v13 * afp(1.25) + v15 * afp(0.25);
afp m52 = v21 - v23 * afp(1.25) + v25 * afp(0.25);
afp m53 = v31 - v33 * afp(1.25) + v35 * afp(0.25);
afp m54 = v41 - v43 * afp(1.25) + v45 * afp(0.25);
afp m55 = v51 - v53 * afp(1.25) + v55 * afp(0.25);

v00 = m00 - m02 * afp(1.25) + m04 * afp(0.25);
v10 = m10 - m12 * afp(1.25) + m14 * afp(0.25);
v20 = m20 - m22 * afp(1.25) + m24 * afp(0.25);
v30 = m30 - m32 * afp(1.25) + m34 * afp(0.25);
v40 = m40 - m42 * afp(1.25) + m44 * afp(0.25);
v50 = m50 - m52 * afp(1.25) + m54 * afp(0.25);

v01 = (m04 + m03) * afp(0.25) - (m01 + m02);
v11 = (m14 + m13) * afp(0.25) - (m11 + m12);
v21 = (m24 + m23) * afp(0.25) - (m21 + m22);
v31 = (m34 + m33) * afp(0.25) - (m31 + m32);
v41 = (m44 + m43) * afp(0.25) - (m41 + m42);
v51 = (m54 + m53) * afp(0.25) - (m51 + m52);

v02 = (m04 - m03) * afp(0.25) + (m01 - m02);
v12 = (m14 - m13) * afp(0.25) + (m11 - m12);
v22 = (m24 - m23) * afp(0.25) + (m21 - m22);
v32 = (m34 - m33) * afp(0.25) + (m31 - m32);
v42 = (m44 - m43) * afp(0.25) + (m41 - m42);
v52 = (m54 - m53) * afp(0.25) + (m51 - m52);

v03 = (m04 - m02) * afp(0.25) - (m01 - m03) * afp(0.5);
v13 = (m14 - m12) * afp(0.25) - (m11 - m13) * afp(0.5);
v23 = (m24 - m22) * afp(0.25) - (m21 - m23) * afp(0.5);
v33 = (m34 - m32) * afp(0.25) - (m31 - m33) * afp(0.5);
v43 = (m44 - m42) * afp(0.25) - (m41 - m43) * afp(0.5);
v53 = (m54 - m52) * afp(0.25) - (m51 - m53) * afp(0.5);

v04 = (m04 - m02) * afp(0.25) + (m01 - m03) * afp(0.5);
v14 = (m14 - m12) * afp(0.25) + (m11 - m13) * afp(0.5);
v24 = (m24 - m22) * afp(0.25) + (m21 - m23) * afp(0.5);
v34 = (m34 - m32) * afp(0.25) + (m31 - m33) * afp(0.5);
v44 = (m44 - m42) * afp(0.25) + (m41 - m43) * afp(0.5);
v54 = (m54 - m52) * afp(0.25) + (m51 - m53) * afp(0.5);

v05 = m01 - m03 * afp(1.25) + m05 * afp(0.25);
v15 = m11 - m13 * afp(1.25) + m15 * afp(0.25);
v25 = m21 - m23 * afp(1.25) + m25 * afp(0.25);
v35 = m31 - m33 * afp(1.25) + m35 * afp(0.25);
v45 = m41 - m43 * afp(1.25) + m45 * afp(0.25);
v55 = m51 - m53 * afp(1.25) + m55 * afp(0.25);
afp m00 = v00 - v02 * afp(2.5) + v04;
afp m01 = v10 - v12 * afp(2.5) + v14;
afp m02 = v20 - v22 * afp(2.5) + v24;
afp m03 = v30 - v32 * afp(2.5) + v34;
afp m04 = v40 - v42 * afp(2.5) + v44;
afp m05 = v50 - v52 * afp(2.5) + v54;

afp s0 = v03 * afp(sq2_d2) - v01 * afp(sq2);
afp s1 = v13 * afp(sq2_d2) - v11 * afp(sq2);
afp s2 = v23 * afp(sq2_d2) - v21 * afp(sq2);
afp s3 = v33 * afp(sq2_d2) - v31 * afp(sq2);
afp s4 = v43 * afp(sq2_d2) - v41 * afp(sq2);
afp s5 = v53 * afp(sq2_d2) - v51 * afp(sq2);

afp t0 = v04 - v02 * afp(2);
afp t1 = v14 - v12 * afp(2);
afp t2 = v24 - v22 * afp(2);
afp t3 = v34 - v32 * afp(2);
afp t4 = v44 - v42 * afp(2);
afp t5 = v54 - v52 * afp(2);

afp m10 = t0 + s0;
afp m11 = t1 + s1;
afp m12 = t2 + s2;
afp m13 = t3 + s3;
afp m14 = t4 + s4;
afp m15 = t5 + s5;

afp m20 = t0 - s0;
afp m21 = t1 - s1;
afp m22 = t2 - s2;
afp m23 = t3 - s3;
afp m24 = t4 - s4;
afp m25 = t5 - s5;

s0 = v03 * afp(sq2) - v01 * afp(sq2_d2);
s1 = v13 * afp(sq2) - v11 * afp(sq2_d2);
s2 = v23 * afp(sq2) - v21 * afp(sq2_d2);
s3 = v33 * afp(sq2) - v31 * afp(sq2_d2);
s4 = v43 * afp(sq2) - v41 * afp(sq2_d2);
s5 = v53 * afp(sq2) - v51 * afp(sq2_d2);

t0 = v04 - v02 * afp(0.5);
t1 = v14 - v12 * afp(0.5);
t2 = v24 - v22 * afp(0.5);
t3 = v34 - v32 * afp(0.5);
t4 = v44 - v42 * afp(0.5);
t5 = v54 - v52 * afp(0.5);

afp m30 = t0 + s0;
afp m31 = t1 + s1;
afp m32 = t2 + s2;
afp m33 = t3 + s3;
afp m34 = t4 + s4;
afp m35 = t5 + s5;

afp m40 = t0 - s0;
afp m41 = t1 - s1;
afp m42 = t2 - s2;
afp m43 = t3 - s3;
afp m44 = t4 - s4;
afp m45 = t5 - s5;

afp m50 = v01 - v03 * afp(2.5) + v05;
afp m51 = v11 - v13 * afp(2.5) + v15;
afp m52 = v21 - v23 * afp(2.5) + v25;
afp m53 = v31 - v33 * afp(2.5) + v35;
afp m54 = v41 - v43 * afp(2.5) + v45;
afp m55 = v51 - v53 * afp(2.5) + v55;

v00 = m00 - m02 * afp(2.5) + m04;
v10 = m10 - m12 * afp(2.5) + m14;
v20 = m20 - m22 * afp(2.5) + m24;
v30 = m30 - m32 * afp(2.5) + m34;
v40 = m40 - m42 * afp(2.5) + m44;
v50 = m50 - m52 * afp(2.5) + m54;

s0 = m03 * afp(sq2_d2) - m01 * afp(sq2);
s1 = m13 * afp(sq2_d2) - m11 * afp(sq2);
s2 = m23 * afp(sq2_d2) - m21 * afp(sq2);
s3 = m33 * afp(sq2_d2) - m31 * afp(sq2);
s4 = m43 * afp(sq2_d2) - m41 * afp(sq2);
s5 = m53 * afp(sq2_d2) - m51 * afp(sq2);

t0 = m04 - m02 * afp(2);
t1 = m14 - m12 * afp(2);
t2 = m24 - m22 * afp(2);
t3 = m34 - m32 * afp(2);
t4 = m44 - m42 * afp(2);
t5 = m54 - m52 * afp(2);

v01 = t0 + s0;
v11 = t1 + s1;
v21 = t2 + s2;
v31 = t3 + s3;
v41 = t4 + s4;
v51 = t5 + s5;

v02 = t0 - s0;
v12 = t1 - s1;
v22 = t2 - s2;
v32 = t3 - s3;
v42 = t4 - s4;
v52 = t5 - s5;

s0 = m03 * afp(sq2) - m01 * afp(sq2_d2);
s1 = m13 * afp(sq2) - m11 * afp(sq2_d2);
s2 = m23 * afp(sq2) - m21 * afp(sq2_d2);
s3 = m33 * afp(sq2) - m31 * afp(sq2_d2);
s4 = m43 * afp(sq2) - m41 * afp(sq2_d2);
s5 = m53 * afp(sq2) - m51 * afp(sq2_d2);

t0 = m04 - m02 * afp(0.5);
t1 = m14 - m12 * afp(0.5);
t2 = m24 - m22 * afp(0.5);
t3 = m34 - m32 * afp(0.5);
t4 = m44 - m42 * afp(0.5);
t5 = m54 - m52 * afp(0.5);

v03 = t0 + s0;
v13 = t1 + s1;
v23 = t2 + s2;
v33 = t3 + s3;
v43 = t4 + s4;
v53 = t5 + s5;

v04 = t0 - s0;
v14 = t1 - s1;
v24 = t2 - s2;
v34 = t3 - s3;
v44 = t4 - s4;
v54 = t5 - s5;

v05 = m01 - m03 * afp(2.5) + m05;
v15 = m11 - m13 * afp(2.5) + m15;
v25 = m21 - m23 * afp(2.5) + m25;
v35 = m31 - m33 * afp(2.5) + m35;
v45 = m41 - m43 * afp(2.5) + m45;
v55 = m51 - m53 * afp(2.5) + m55;

// store 36
#if NCNN_image_shader


+ 102
- 56
src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_output.comp View File

@@ -153,66 +153,112 @@ void main()
afp v55 = buffer_ld1(top_tm_blob_data, v_tm_offset + 35 * psc(cstep));
#endif

#define sq2 1.41421356237
#define sq2_d4 1.41421356237/4

// const float otm[4][6] = {
// {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
// {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
// {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f},
// {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
// {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
// {0.0f, sq2/2, -sq2/2, sq2, -sq2, 0.0f},
// {0.0f, 0.5f, 0.5f, 2.0f, 2.0f, 0.0f},
// {0.0f, sq2/4, -sq2/4, sq2*2, -sq2*2, 1.0f}
// };

// 0 = r00 + (r01 + r02) + (r03 + r04)
// 1 = (r01 - r02) + (r03 - r04) * 2
// 2 = (r01 + r02) + (r03 + r04) * 4
// 3 = r05 + (r01 - r02) + (r03 - r04) * 8

// implicit transpose
afp m00 = v00 + (v01 + v02) + (v03 + v04);
afp m01 = v10 + (v11 + v12) + (v13 + v14);
afp m02 = v20 + (v21 + v22) + (v23 + v24);
afp m03 = v30 + (v31 + v32) + (v33 + v34);
afp m04 = v40 + (v41 + v42) + (v43 + v44);
afp m05 = v50 + (v51 + v52) + (v53 + v54);

afp m10 = (v01 - v02) + (v03 - v04) * afp(2);
afp m11 = (v11 - v12) + (v13 - v14) * afp(2);
afp m12 = (v21 - v22) + (v23 - v24) * afp(2);
afp m13 = (v31 - v32) + (v33 - v34) * afp(2);
afp m14 = (v41 - v42) + (v43 - v44) * afp(2);
afp m15 = (v51 - v52) + (v53 - v54) * afp(2);

afp m20 = (v01 + v02) + (v03 + v04) * afp(4);
afp m21 = (v11 + v12) + (v13 + v14) * afp(4);
afp m22 = (v21 + v22) + (v23 + v24) * afp(4);
afp m23 = (v31 + v32) + (v33 + v34) * afp(4);
afp m24 = (v41 + v42) + (v43 + v44) * afp(4);
afp m25 = (v51 + v52) + (v53 + v54) * afp(4);

afp m30 = v05 + (v01 - v02) + (v03 - v04) * afp(8);
afp m31 = v15 + (v11 - v12) + (v13 - v14) * afp(8);
afp m32 = v25 + (v21 - v22) + (v23 - v24) * afp(8);
afp m33 = v35 + (v31 - v32) + (v33 - v34) * afp(8);
afp m34 = v45 + (v41 - v42) + (v43 - v44) * afp(8);
afp m35 = v55 + (v51 - v52) + (v53 - v54) * afp(8);

v00 = m00 + (m01 + m02) + (m03 + m04);
v10 = m10 + (m11 + m12) + (m13 + m14);
v20 = m20 + (m21 + m22) + (m23 + m24);
v30 = m30 + (m31 + m32) + (m33 + m34);

v01 = (m01 - m02) + (m03 - m04) * afp(2);
v11 = (m11 - m12) + (m13 - m14) * afp(2);
v21 = (m21 - m22) + (m23 - m24) * afp(2);
v31 = (m31 - m32) + (m33 - m34) * afp(2);

v02 = (m01 + m02) + (m03 + m04) * afp(4);
v12 = (m11 + m12) + (m13 + m14) * afp(4);
v22 = (m21 + m22) + (m23 + m24) * afp(4);
v32 = (m31 + m32) + (m33 + m34) * afp(4);

v03 = m05 + (m01 - m02) + (m03 - m04) * afp(8);
v13 = m15 + (m11 - m12) + (m13 - m14) * afp(8);
v23 = m25 + (m21 - m22) + (m23 - m24) * afp(8);
v33 = m35 + (m31 - m32) + (m33 - m34) * afp(8);
afp s0 = (v01 + v02) * afp(0.5);
afp s1 = (v11 + v12) * afp(0.5);
afp s2 = (v21 + v22) * afp(0.5);
afp s3 = (v31 + v32) * afp(0.5);
afp s4 = (v41 + v42) * afp(0.5);
afp s5 = (v51 + v52) * afp(0.5);

afp t0 = (v01 - v02) * afp(sq2_d4);
afp t1 = (v11 - v12) * afp(sq2_d4);
afp t2 = (v21 - v22) * afp(sq2_d4);
afp t3 = (v31 - v32) * afp(sq2_d4);
afp t4 = (v41 - v42) * afp(sq2_d4);
afp t5 = (v51 - v52) * afp(sq2_d4);

afp u0 = v03 + v04;
afp u1 = v13 + v14;
afp u2 = v23 + v24;
afp u3 = v33 + v34;
afp u4 = v43 + v44;
afp u5 = v53 + v54;

afp v0 = (v03 - v04) * afp(sq2);
afp v1 = (v13 - v14) * afp(sq2);
afp v2 = (v23 - v24) * afp(sq2);
afp v3 = (v33 - v34) * afp(sq2);
afp v4 = (v43 - v44) * afp(sq2);
afp v5 = (v53 - v54) * afp(sq2);

afp m00 = v00 + s0 + s0 + u0;
afp m01 = v10 + s1 + s1 + u1;
afp m02 = v20 + s2 + s2 + u2;
afp m03 = v30 + s3 + s3 + u3;
afp m04 = v40 + s4 + s4 + u4;
afp m05 = v50 + s5 + s5 + u5;

afp m10 = t0 + t0 + v0;
afp m11 = t1 + t1 + v1;
afp m12 = t2 + t2 + v2;
afp m13 = t3 + t3 + v3;
afp m14 = t4 + t4 + v4;
afp m15 = t5 + t5 + v5;

afp m20 = s0 + u0 + u0;
afp m21 = s1 + u1 + u1;
afp m22 = s2 + u2 + u2;
afp m23 = s3 + u3 + u3;
afp m24 = s4 + u4 + u4;
afp m25 = s5 + u5 + u5;

afp m30 = v05 + t0 + v0 + v0;
afp m31 = v15 + t1 + v1 + v1;
afp m32 = v25 + t2 + v2 + v2;
afp m33 = v35 + t3 + v3 + v3;
afp m34 = v45 + t4 + v4 + v4;
afp m35 = v55 + t5 + v5 + v5;

s0 = (m01 + m02) * afp(0.5);
s1 = (m11 + m12) * afp(0.5);
s2 = (m21 + m22) * afp(0.5);
s3 = (m31 + m32) * afp(0.5);

t0 = (m01 - m02) * afp(sq2_d4);
t1 = (m11 - m12) * afp(sq2_d4);
t2 = (m21 - m22) * afp(sq2_d4);
t3 = (m31 - m32) * afp(sq2_d4);

u0 = m03 + m04;
u1 = m13 + m14;
u2 = m23 + m24;
u3 = m33 + m34;

v0 = (m03 - m04) * afp(sq2);
v1 = (m13 - m14) * afp(sq2);
v2 = (m23 - m24) * afp(sq2);
v3 = (m33 - m34) * afp(sq2);

v00 = m00 + s0 + s0 + u0;
v10 = m10 + s1 + s1 + u1;
v20 = m20 + s2 + s2 + u2;
v30 = m30 + s3 + s3 + u3;

v01 = t0 + t0 + v0;
v11 = t1 + t1 + v1;
v21 = t2 + t2 + v2;
v31 = t3 + t3 + v3;

v02 = s0 + u0 + u0;
v12 = s1 + u1 + u1;
v22 = s2 + u2 + u2;
v32 = s3 + u3 + u3;

v03 = m05 + t0 + v0 + v0;
v13 = m15 + t1 + v1 + v1;
v23 = m25 + t2 + v2 + v2;
v33 = m35 + t3 + v3 + v3;

if (bias_term == 1)
{


+ 148
- 96
src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_input.comp View File

@@ -156,106 +156,158 @@ void main()
afpvec4 v55 = sy + 5 < psc(h) && sx + 5 < psc(w) ? buffer_ld4(bottom_blob_data, v_offset45.y + 5) : afpvec4(0.f);
#endif

#define sq2 1.41421356237
#define sq2_d2 1.41421356237/2

// const float itm[6][6] = {
// {1.0f, 0.0f, -1.25f, 0.0f, 0.25f, 0.0f},
// {0.0f,-1.0f, -1.0f, 0.25f, 0.25f, 0.0f},
// {0.0f, 1.0f, -1.0f,-0.25f, 0.25f, 0.0f},
// {0.0f,-0.5f, -0.25f, 0.5f, 0.25f, 0.0f},
// {0.0f, 0.5f, -0.25f,-0.5f, 0.25f, 0.0f},
// {0.0f, 1.0f, 0.0f,-1.25f, 0.0f, 0.25f}
// {1.0f, 0.0f, -2.5f, 0.0f, 1.0f, 0.0f},
// {0.0f, -sq2, -2.0f, sq2/2, 1.0f, 0.0f},
// {0.0f, sq2, -2.0f, -sq2/2, 1.0f, 0.0f},
// {0.0f, -sq2/2, -0.5f, sq2, 1.0f, 0.0f},
// {0.0f, sq2/2, -0.5f, -sq2, 1.0f, 0.0f},
// {0.0f, 1.0f, 0.0f, -2.5f, 0.0f, 1.0f}
// };

// 0 = 1 * r00 - 1.25 * r02 + 0.25 * r04
// 1 = -1 * (r01 + r02) + 0.25 * (r04 + r03)
// 2 = 1 * (r01 - r02) + 0.25 * (r04 - r03)
// 3 = -0.5 * (r01 - r03) + 0.25 * (r04 - r02)
// 4 = 0.5 * (r01 - r03) + 0.25 * (r04 - r02)
// 5 = 1 * r01 - 1.25 * r03 + 0.25 * r05

// implicit transpose
afpvec4 m00 = v00 - v02 * afp(1.25) + v04 * afp(0.25);
afpvec4 m01 = v10 - v12 * afp(1.25) + v14 * afp(0.25);
afpvec4 m02 = v20 - v22 * afp(1.25) + v24 * afp(0.25);
afpvec4 m03 = v30 - v32 * afp(1.25) + v34 * afp(0.25);
afpvec4 m04 = v40 - v42 * afp(1.25) + v44 * afp(0.25);
afpvec4 m05 = v50 - v52 * afp(1.25) + v54 * afp(0.25);

afpvec4 m10 = (v04 + v03) * afp(0.25) - (v01 + v02);
afpvec4 m11 = (v14 + v13) * afp(0.25) - (v11 + v12);
afpvec4 m12 = (v24 + v23) * afp(0.25) - (v21 + v22);
afpvec4 m13 = (v34 + v33) * afp(0.25) - (v31 + v32);
afpvec4 m14 = (v44 + v43) * afp(0.25) - (v41 + v42);
afpvec4 m15 = (v54 + v53) * afp(0.25) - (v51 + v52);

afpvec4 m20 = (v04 - v03) * afp(0.25) + (v01 - v02);
afpvec4 m21 = (v14 - v13) * afp(0.25) + (v11 - v12);
afpvec4 m22 = (v24 - v23) * afp(0.25) + (v21 - v22);
afpvec4 m23 = (v34 - v33) * afp(0.25) + (v31 - v32);
afpvec4 m24 = (v44 - v43) * afp(0.25) + (v41 - v42);
afpvec4 m25 = (v54 - v53) * afp(0.25) + (v51 - v52);

afpvec4 m30 = (v04 - v02) * afp(0.25) - (v01 - v03) * afp(0.5);
afpvec4 m31 = (v14 - v12) * afp(0.25) - (v11 - v13) * afp(0.5);
afpvec4 m32 = (v24 - v22) * afp(0.25) - (v21 - v23) * afp(0.5);
afpvec4 m33 = (v34 - v32) * afp(0.25) - (v31 - v33) * afp(0.5);
afpvec4 m34 = (v44 - v42) * afp(0.25) - (v41 - v43) * afp(0.5);
afpvec4 m35 = (v54 - v52) * afp(0.25) - (v51 - v53) * afp(0.5);

afpvec4 m40 = (v04 - v02) * afp(0.25) + (v01 - v03) * afp(0.5);
afpvec4 m41 = (v14 - v12) * afp(0.25) + (v11 - v13) * afp(0.5);
afpvec4 m42 = (v24 - v22) * afp(0.25) + (v21 - v23) * afp(0.5);
afpvec4 m43 = (v34 - v32) * afp(0.25) + (v31 - v33) * afp(0.5);
afpvec4 m44 = (v44 - v42) * afp(0.25) + (v41 - v43) * afp(0.5);
afpvec4 m45 = (v54 - v52) * afp(0.25) + (v51 - v53) * afp(0.5);

afpvec4 m50 = v01 - v03 * afp(1.25) + v05 * afp(0.25);
afpvec4 m51 = v11 - v13 * afp(1.25) + v15 * afp(0.25);
afpvec4 m52 = v21 - v23 * afp(1.25) + v25 * afp(0.25);
afpvec4 m53 = v31 - v33 * afp(1.25) + v35 * afp(0.25);
afpvec4 m54 = v41 - v43 * afp(1.25) + v45 * afp(0.25);
afpvec4 m55 = v51 - v53 * afp(1.25) + v55 * afp(0.25);

v00 = m00 - m02 * afp(1.25) + m04 * afp(0.25);
v10 = m10 - m12 * afp(1.25) + m14 * afp(0.25);
v20 = m20 - m22 * afp(1.25) + m24 * afp(0.25);
v30 = m30 - m32 * afp(1.25) + m34 * afp(0.25);
v40 = m40 - m42 * afp(1.25) + m44 * afp(0.25);
v50 = m50 - m52 * afp(1.25) + m54 * afp(0.25);

v01 = (m04 + m03) * afp(0.25) - (m01 + m02);
v11 = (m14 + m13) * afp(0.25) - (m11 + m12);
v21 = (m24 + m23) * afp(0.25) - (m21 + m22);
v31 = (m34 + m33) * afp(0.25) - (m31 + m32);
v41 = (m44 + m43) * afp(0.25) - (m41 + m42);
v51 = (m54 + m53) * afp(0.25) - (m51 + m52);

v02 = (m04 - m03) * afp(0.25) + (m01 - m02);
v12 = (m14 - m13) * afp(0.25) + (m11 - m12);
v22 = (m24 - m23) * afp(0.25) + (m21 - m22);
v32 = (m34 - m33) * afp(0.25) + (m31 - m32);
v42 = (m44 - m43) * afp(0.25) + (m41 - m42);
v52 = (m54 - m53) * afp(0.25) + (m51 - m52);

v03 = (m04 - m02) * afp(0.25) - (m01 - m03) * afp(0.5);
v13 = (m14 - m12) * afp(0.25) - (m11 - m13) * afp(0.5);
v23 = (m24 - m22) * afp(0.25) - (m21 - m23) * afp(0.5);
v33 = (m34 - m32) * afp(0.25) - (m31 - m33) * afp(0.5);
v43 = (m44 - m42) * afp(0.25) - (m41 - m43) * afp(0.5);
v53 = (m54 - m52) * afp(0.25) - (m51 - m53) * afp(0.5);

v04 = (m04 - m02) * afp(0.25) + (m01 - m03) * afp(0.5);
v14 = (m14 - m12) * afp(0.25) + (m11 - m13) * afp(0.5);
v24 = (m24 - m22) * afp(0.25) + (m21 - m23) * afp(0.5);
v34 = (m34 - m32) * afp(0.25) + (m31 - m33) * afp(0.5);
v44 = (m44 - m42) * afp(0.25) + (m41 - m43) * afp(0.5);
v54 = (m54 - m52) * afp(0.25) + (m51 - m53) * afp(0.5);

v05 = m01 - m03 * afp(1.25) + m05 * afp(0.25);
v15 = m11 - m13 * afp(1.25) + m15 * afp(0.25);
v25 = m21 - m23 * afp(1.25) + m25 * afp(0.25);
v35 = m31 - m33 * afp(1.25) + m35 * afp(0.25);
v45 = m41 - m43 * afp(1.25) + m45 * afp(0.25);
v55 = m51 - m53 * afp(1.25) + m55 * afp(0.25);
afpvec4 m00 = v00 - v02 * afp(2.5) + v04;
afpvec4 m01 = v10 - v12 * afp(2.5) + v14;
afpvec4 m02 = v20 - v22 * afp(2.5) + v24;
afpvec4 m03 = v30 - v32 * afp(2.5) + v34;
afpvec4 m04 = v40 - v42 * afp(2.5) + v44;
afpvec4 m05 = v50 - v52 * afp(2.5) + v54;

afpvec4 s0 = v03 * afp(sq2_d2) - v01 * afp(sq2);
afpvec4 s1 = v13 * afp(sq2_d2) - v11 * afp(sq2);
afpvec4 s2 = v23 * afp(sq2_d2) - v21 * afp(sq2);
afpvec4 s3 = v33 * afp(sq2_d2) - v31 * afp(sq2);
afpvec4 s4 = v43 * afp(sq2_d2) - v41 * afp(sq2);
afpvec4 s5 = v53 * afp(sq2_d2) - v51 * afp(sq2);

afpvec4 t0 = v04 - v02 * afp(2);
afpvec4 t1 = v14 - v12 * afp(2);
afpvec4 t2 = v24 - v22 * afp(2);
afpvec4 t3 = v34 - v32 * afp(2);
afpvec4 t4 = v44 - v42 * afp(2);
afpvec4 t5 = v54 - v52 * afp(2);

afpvec4 m10 = t0 + s0;
afpvec4 m11 = t1 + s1;
afpvec4 m12 = t2 + s2;
afpvec4 m13 = t3 + s3;
afpvec4 m14 = t4 + s4;
afpvec4 m15 = t5 + s5;

afpvec4 m20 = t0 - s0;
afpvec4 m21 = t1 - s1;
afpvec4 m22 = t2 - s2;
afpvec4 m23 = t3 - s3;
afpvec4 m24 = t4 - s4;
afpvec4 m25 = t5 - s5;

s0 = v03 * afp(sq2) - v01 * afp(sq2_d2);
s1 = v13 * afp(sq2) - v11 * afp(sq2_d2);
s2 = v23 * afp(sq2) - v21 * afp(sq2_d2);
s3 = v33 * afp(sq2) - v31 * afp(sq2_d2);
s4 = v43 * afp(sq2) - v41 * afp(sq2_d2);
s5 = v53 * afp(sq2) - v51 * afp(sq2_d2);

t0 = v04 - v02 * afp(0.5);
t1 = v14 - v12 * afp(0.5);
t2 = v24 - v22 * afp(0.5);
t3 = v34 - v32 * afp(0.5);
t4 = v44 - v42 * afp(0.5);
t5 = v54 - v52 * afp(0.5);

afpvec4 m30 = t0 + s0;
afpvec4 m31 = t1 + s1;
afpvec4 m32 = t2 + s2;
afpvec4 m33 = t3 + s3;
afpvec4 m34 = t4 + s4;
afpvec4 m35 = t5 + s5;

afpvec4 m40 = t0 - s0;
afpvec4 m41 = t1 - s1;
afpvec4 m42 = t2 - s2;
afpvec4 m43 = t3 - s3;
afpvec4 m44 = t4 - s4;
afpvec4 m45 = t5 - s5;

afpvec4 m50 = v01 - v03 * afp(2.5) + v05;
afpvec4 m51 = v11 - v13 * afp(2.5) + v15;
afpvec4 m52 = v21 - v23 * afp(2.5) + v25;
afpvec4 m53 = v31 - v33 * afp(2.5) + v35;
afpvec4 m54 = v41 - v43 * afp(2.5) + v45;
afpvec4 m55 = v51 - v53 * afp(2.5) + v55;

v00 = m00 - m02 * afp(2.5) + m04;
v10 = m10 - m12 * afp(2.5) + m14;
v20 = m20 - m22 * afp(2.5) + m24;
v30 = m30 - m32 * afp(2.5) + m34;
v40 = m40 - m42 * afp(2.5) + m44;
v50 = m50 - m52 * afp(2.5) + m54;

s0 = m03 * afp(sq2_d2) - m01 * afp(sq2);
s1 = m13 * afp(sq2_d2) - m11 * afp(sq2);
s2 = m23 * afp(sq2_d2) - m21 * afp(sq2);
s3 = m33 * afp(sq2_d2) - m31 * afp(sq2);
s4 = m43 * afp(sq2_d2) - m41 * afp(sq2);
s5 = m53 * afp(sq2_d2) - m51 * afp(sq2);

t0 = m04 - m02 * afp(2);
t1 = m14 - m12 * afp(2);
t2 = m24 - m22 * afp(2);
t3 = m34 - m32 * afp(2);
t4 = m44 - m42 * afp(2);
t5 = m54 - m52 * afp(2);

v01 = t0 + s0;
v11 = t1 + s1;
v21 = t2 + s2;
v31 = t3 + s3;
v41 = t4 + s4;
v51 = t5 + s5;

v02 = t0 - s0;
v12 = t1 - s1;
v22 = t2 - s2;
v32 = t3 - s3;
v42 = t4 - s4;
v52 = t5 - s5;

s0 = m03 * afp(sq2) - m01 * afp(sq2_d2);
s1 = m13 * afp(sq2) - m11 * afp(sq2_d2);
s2 = m23 * afp(sq2) - m21 * afp(sq2_d2);
s3 = m33 * afp(sq2) - m31 * afp(sq2_d2);
s4 = m43 * afp(sq2) - m41 * afp(sq2_d2);
s5 = m53 * afp(sq2) - m51 * afp(sq2_d2);

t0 = m04 - m02 * afp(0.5);
t1 = m14 - m12 * afp(0.5);
t2 = m24 - m22 * afp(0.5);
t3 = m34 - m32 * afp(0.5);
t4 = m44 - m42 * afp(0.5);
t5 = m54 - m52 * afp(0.5);

v03 = t0 + s0;
v13 = t1 + s1;
v23 = t2 + s2;
v33 = t3 + s3;
v43 = t4 + s4;
v53 = t5 + s5;

v04 = t0 - s0;
v14 = t1 - s1;
v24 = t2 - s2;
v34 = t3 - s3;
v44 = t4 - s4;
v54 = t5 - s5;

v05 = m01 - m03 * afp(2.5) + m05;
v15 = m11 - m13 * afp(2.5) + m15;
v25 = m21 - m23 * afp(2.5) + m25;
v35 = m31 - m33 * afp(2.5) + m35;
v45 = m41 - m43 * afp(2.5) + m45;
v55 = m51 - m53 * afp(2.5) + m55;

// store 36
#if NCNN_image_shader


+ 102
- 56
src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_output.comp View File

@@ -153,66 +153,112 @@ void main()
afpvec4 v55 = buffer_ld4(top_tm_blob_data, v_tm_offset + 35 * psc(cstep));
#endif

#define sq2 1.41421356237
#define sq2_d4 1.41421356237/4

// const float otm[4][6] = {
// {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
// {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
// {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f},
// {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
// {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
// {0.0f, sq2/2, -sq2/2, sq2, -sq2, 0.0f},
// {0.0f, 0.5f, 0.5f, 2.0f, 2.0f, 0.0f},
// {0.0f, sq2/4, -sq2/4, sq2*2, -sq2*2, 1.0f}
// };

// 0 = r00 + (r01 + r02) + (r03 + r04)
// 1 = (r01 - r02) + (r03 - r04) * 2
// 2 = (r01 + r02) + (r03 + r04) * 4
// 3 = r05 + (r01 - r02) + (r03 - r04) * 8

// implicit transpose
afpvec4 m00 = v00 + (v01 + v02) + (v03 + v04);
afpvec4 m01 = v10 + (v11 + v12) + (v13 + v14);
afpvec4 m02 = v20 + (v21 + v22) + (v23 + v24);
afpvec4 m03 = v30 + (v31 + v32) + (v33 + v34);
afpvec4 m04 = v40 + (v41 + v42) + (v43 + v44);
afpvec4 m05 = v50 + (v51 + v52) + (v53 + v54);

afpvec4 m10 = (v01 - v02) + (v03 - v04) * afp(2);
afpvec4 m11 = (v11 - v12) + (v13 - v14) * afp(2);
afpvec4 m12 = (v21 - v22) + (v23 - v24) * afp(2);
afpvec4 m13 = (v31 - v32) + (v33 - v34) * afp(2);
afpvec4 m14 = (v41 - v42) + (v43 - v44) * afp(2);
afpvec4 m15 = (v51 - v52) + (v53 - v54) * afp(2);

afpvec4 m20 = (v01 + v02) + (v03 + v04) * afp(4);
afpvec4 m21 = (v11 + v12) + (v13 + v14) * afp(4);
afpvec4 m22 = (v21 + v22) + (v23 + v24) * afp(4);
afpvec4 m23 = (v31 + v32) + (v33 + v34) * afp(4);
afpvec4 m24 = (v41 + v42) + (v43 + v44) * afp(4);
afpvec4 m25 = (v51 + v52) + (v53 + v54) * afp(4);

afpvec4 m30 = v05 + (v01 - v02) + (v03 - v04) * afp(8);
afpvec4 m31 = v15 + (v11 - v12) + (v13 - v14) * afp(8);
afpvec4 m32 = v25 + (v21 - v22) + (v23 - v24) * afp(8);
afpvec4 m33 = v35 + (v31 - v32) + (v33 - v34) * afp(8);
afpvec4 m34 = v45 + (v41 - v42) + (v43 - v44) * afp(8);
afpvec4 m35 = v55 + (v51 - v52) + (v53 - v54) * afp(8);

v00 = m00 + (m01 + m02) + (m03 + m04);
v10 = m10 + (m11 + m12) + (m13 + m14);
v20 = m20 + (m21 + m22) + (m23 + m24);
v30 = m30 + (m31 + m32) + (m33 + m34);

v01 = (m01 - m02) + (m03 - m04) * afp(2);
v11 = (m11 - m12) + (m13 - m14) * afp(2);
v21 = (m21 - m22) + (m23 - m24) * afp(2);
v31 = (m31 - m32) + (m33 - m34) * afp(2);

v02 = (m01 + m02) + (m03 + m04) * afp(4);
v12 = (m11 + m12) + (m13 + m14) * afp(4);
v22 = (m21 + m22) + (m23 + m24) * afp(4);
v32 = (m31 + m32) + (m33 + m34) * afp(4);

v03 = m05 + (m01 - m02) + (m03 - m04) * afp(8);
v13 = m15 + (m11 - m12) + (m13 - m14) * afp(8);
v23 = m25 + (m21 - m22) + (m23 - m24) * afp(8);
v33 = m35 + (m31 - m32) + (m33 - m34) * afp(8);
afpvec4 s0 = (v01 + v02) * afp(0.5);
afpvec4 s1 = (v11 + v12) * afp(0.5);
afpvec4 s2 = (v21 + v22) * afp(0.5);
afpvec4 s3 = (v31 + v32) * afp(0.5);
afpvec4 s4 = (v41 + v42) * afp(0.5);
afpvec4 s5 = (v51 + v52) * afp(0.5);

afpvec4 t0 = (v01 - v02) * afp(sq2_d4);
afpvec4 t1 = (v11 - v12) * afp(sq2_d4);
afpvec4 t2 = (v21 - v22) * afp(sq2_d4);
afpvec4 t3 = (v31 - v32) * afp(sq2_d4);
afpvec4 t4 = (v41 - v42) * afp(sq2_d4);
afpvec4 t5 = (v51 - v52) * afp(sq2_d4);

afpvec4 u0 = v03 + v04;
afpvec4 u1 = v13 + v14;
afpvec4 u2 = v23 + v24;
afpvec4 u3 = v33 + v34;
afpvec4 u4 = v43 + v44;
afpvec4 u5 = v53 + v54;

afpvec4 v0 = (v03 - v04) * afp(sq2);
afpvec4 v1 = (v13 - v14) * afp(sq2);
afpvec4 v2 = (v23 - v24) * afp(sq2);
afpvec4 v3 = (v33 - v34) * afp(sq2);
afpvec4 v4 = (v43 - v44) * afp(sq2);
afpvec4 v5 = (v53 - v54) * afp(sq2);

afpvec4 m00 = v00 + s0 + s0 + u0;
afpvec4 m01 = v10 + s1 + s1 + u1;
afpvec4 m02 = v20 + s2 + s2 + u2;
afpvec4 m03 = v30 + s3 + s3 + u3;
afpvec4 m04 = v40 + s4 + s4 + u4;
afpvec4 m05 = v50 + s5 + s5 + u5;

afpvec4 m10 = t0 + t0 + v0;
afpvec4 m11 = t1 + t1 + v1;
afpvec4 m12 = t2 + t2 + v2;
afpvec4 m13 = t3 + t3 + v3;
afpvec4 m14 = t4 + t4 + v4;
afpvec4 m15 = t5 + t5 + v5;

afpvec4 m20 = s0 + u0 + u0;
afpvec4 m21 = s1 + u1 + u1;
afpvec4 m22 = s2 + u2 + u2;
afpvec4 m23 = s3 + u3 + u3;
afpvec4 m24 = s4 + u4 + u4;
afpvec4 m25 = s5 + u5 + u5;

afpvec4 m30 = v05 + t0 + v0 + v0;
afpvec4 m31 = v15 + t1 + v1 + v1;
afpvec4 m32 = v25 + t2 + v2 + v2;
afpvec4 m33 = v35 + t3 + v3 + v3;
afpvec4 m34 = v45 + t4 + v4 + v4;
afpvec4 m35 = v55 + t5 + v5 + v5;

s0 = (m01 + m02) * afp(0.5);
s1 = (m11 + m12) * afp(0.5);
s2 = (m21 + m22) * afp(0.5);
s3 = (m31 + m32) * afp(0.5);

t0 = (m01 - m02) * afp(sq2_d4);
t1 = (m11 - m12) * afp(sq2_d4);
t2 = (m21 - m22) * afp(sq2_d4);
t3 = (m31 - m32) * afp(sq2_d4);

u0 = m03 + m04;
u1 = m13 + m14;
u2 = m23 + m24;
u3 = m33 + m34;

v0 = (m03 - m04) * afp(sq2);
v1 = (m13 - m14) * afp(sq2);
v2 = (m23 - m24) * afp(sq2);
v3 = (m33 - m34) * afp(sq2);

v00 = m00 + s0 + s0 + u0;
v10 = m10 + s1 + s1 + u1;
v20 = m20 + s2 + s2 + u2;
v30 = m30 + s3 + s3 + u3;

v01 = t0 + t0 + v0;
v11 = t1 + t1 + v1;
v21 = t2 + t2 + v2;
v31 = t3 + t3 + v3;

v02 = s0 + u0 + u0;
v12 = s1 + u1 + u1;
v22 = s2 + u2 + u2;
v32 = s3 + u3 + u3;

v03 = m05 + t0 + v0 + v0;
v13 = m15 + t1 + v1 + v1;
v23 = m25 + t2 + v2 + v2;
v33 = m35 + t3 + v3 + v3;

if (bias_term == 1)
{


+ 148
- 96
src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_input.comp View File

@@ -157,106 +157,158 @@ void main()
afpvec8 v55 = sy + 5 < psc(h) && sx + 5 < psc(w) ? buffer_ld8(bottom_blob_data, v_offset45.y + 5) : afpvec8(afpvec4(0.f), afpvec4(0.f));
#endif

#define sq2 1.41421356237
#define sq2_d2 1.41421356237/2

// const float itm[6][6] = {
// {1.0f, 0.0f, -1.25f, 0.0f, 0.25f, 0.0f},
// {0.0f,-1.0f, -1.0f, 0.25f, 0.25f, 0.0f},
// {0.0f, 1.0f, -1.0f,-0.25f, 0.25f, 0.0f},
// {0.0f,-0.5f, -0.25f, 0.5f, 0.25f, 0.0f},
// {0.0f, 0.5f, -0.25f,-0.5f, 0.25f, 0.0f},
// {0.0f, 1.0f, 0.0f,-1.25f, 0.0f, 0.25f}
// {1.0f, 0.0f, -2.5f, 0.0f, 1.0f, 0.0f},
// {0.0f, -sq2, -2.0f, sq2/2, 1.0f, 0.0f},
// {0.0f, sq2, -2.0f, -sq2/2, 1.0f, 0.0f},
// {0.0f, -sq2/2, -0.5f, sq2, 1.0f, 0.0f},
// {0.0f, sq2/2, -0.5f, -sq2, 1.0f, 0.0f},
// {0.0f, 1.0f, 0.0f, -2.5f, 0.0f, 1.0f}
// };

// 0 = 1 * r00 - 1.25 * r02 + 0.25 * r04
// 1 = -1 * (r01 + r02) + 0.25 * (r04 + r03)
// 2 = 1 * (r01 - r02) + 0.25 * (r04 - r03)
// 3 = -0.5 * (r01 - r03) + 0.25 * (r04 - r02)
// 4 = 0.5 * (r01 - r03) + 0.25 * (r04 - r02)
// 5 = 1 * r01 - 1.25 * r03 + 0.25 * r05

// implicit transpose
afpvec8 m00 = v00 - v02 * afp(1.25) + v04 * afp(0.25);
afpvec8 m01 = v10 - v12 * afp(1.25) + v14 * afp(0.25);
afpvec8 m02 = v20 - v22 * afp(1.25) + v24 * afp(0.25);
afpvec8 m03 = v30 - v32 * afp(1.25) + v34 * afp(0.25);
afpvec8 m04 = v40 - v42 * afp(1.25) + v44 * afp(0.25);
afpvec8 m05 = v50 - v52 * afp(1.25) + v54 * afp(0.25);

afpvec8 m10 = (v04 + v03) * afp(0.25) - (v01 + v02);
afpvec8 m11 = (v14 + v13) * afp(0.25) - (v11 + v12);
afpvec8 m12 = (v24 + v23) * afp(0.25) - (v21 + v22);
afpvec8 m13 = (v34 + v33) * afp(0.25) - (v31 + v32);
afpvec8 m14 = (v44 + v43) * afp(0.25) - (v41 + v42);
afpvec8 m15 = (v54 + v53) * afp(0.25) - (v51 + v52);

afpvec8 m20 = (v04 - v03) * afp(0.25) + (v01 - v02);
afpvec8 m21 = (v14 - v13) * afp(0.25) + (v11 - v12);
afpvec8 m22 = (v24 - v23) * afp(0.25) + (v21 - v22);
afpvec8 m23 = (v34 - v33) * afp(0.25) + (v31 - v32);
afpvec8 m24 = (v44 - v43) * afp(0.25) + (v41 - v42);
afpvec8 m25 = (v54 - v53) * afp(0.25) + (v51 - v52);

afpvec8 m30 = (v04 - v02) * afp(0.25) - (v01 - v03) * afp(0.5);
afpvec8 m31 = (v14 - v12) * afp(0.25) - (v11 - v13) * afp(0.5);
afpvec8 m32 = (v24 - v22) * afp(0.25) - (v21 - v23) * afp(0.5);
afpvec8 m33 = (v34 - v32) * afp(0.25) - (v31 - v33) * afp(0.5);
afpvec8 m34 = (v44 - v42) * afp(0.25) - (v41 - v43) * afp(0.5);
afpvec8 m35 = (v54 - v52) * afp(0.25) - (v51 - v53) * afp(0.5);

afpvec8 m40 = (v04 - v02) * afp(0.25) + (v01 - v03) * afp(0.5);
afpvec8 m41 = (v14 - v12) * afp(0.25) + (v11 - v13) * afp(0.5);
afpvec8 m42 = (v24 - v22) * afp(0.25) + (v21 - v23) * afp(0.5);
afpvec8 m43 = (v34 - v32) * afp(0.25) + (v31 - v33) * afp(0.5);
afpvec8 m44 = (v44 - v42) * afp(0.25) + (v41 - v43) * afp(0.5);
afpvec8 m45 = (v54 - v52) * afp(0.25) + (v51 - v53) * afp(0.5);

afpvec8 m50 = v01 - v03 * afp(1.25) + v05 * afp(0.25);
afpvec8 m51 = v11 - v13 * afp(1.25) + v15 * afp(0.25);
afpvec8 m52 = v21 - v23 * afp(1.25) + v25 * afp(0.25);
afpvec8 m53 = v31 - v33 * afp(1.25) + v35 * afp(0.25);
afpvec8 m54 = v41 - v43 * afp(1.25) + v45 * afp(0.25);
afpvec8 m55 = v51 - v53 * afp(1.25) + v55 * afp(0.25);

v00 = m00 - m02 * afp(1.25) + m04 * afp(0.25);
v10 = m10 - m12 * afp(1.25) + m14 * afp(0.25);
v20 = m20 - m22 * afp(1.25) + m24 * afp(0.25);
v30 = m30 - m32 * afp(1.25) + m34 * afp(0.25);
v40 = m40 - m42 * afp(1.25) + m44 * afp(0.25);
v50 = m50 - m52 * afp(1.25) + m54 * afp(0.25);

v01 = (m04 + m03) * afp(0.25) - (m01 + m02);
v11 = (m14 + m13) * afp(0.25) - (m11 + m12);
v21 = (m24 + m23) * afp(0.25) - (m21 + m22);
v31 = (m34 + m33) * afp(0.25) - (m31 + m32);
v41 = (m44 + m43) * afp(0.25) - (m41 + m42);
v51 = (m54 + m53) * afp(0.25) - (m51 + m52);

v02 = (m04 - m03) * afp(0.25) + (m01 - m02);
v12 = (m14 - m13) * afp(0.25) + (m11 - m12);
v22 = (m24 - m23) * afp(0.25) + (m21 - m22);
v32 = (m34 - m33) * afp(0.25) + (m31 - m32);
v42 = (m44 - m43) * afp(0.25) + (m41 - m42);
v52 = (m54 - m53) * afp(0.25) + (m51 - m52);

v03 = (m04 - m02) * afp(0.25) - (m01 - m03) * afp(0.5);
v13 = (m14 - m12) * afp(0.25) - (m11 - m13) * afp(0.5);
v23 = (m24 - m22) * afp(0.25) - (m21 - m23) * afp(0.5);
v33 = (m34 - m32) * afp(0.25) - (m31 - m33) * afp(0.5);
v43 = (m44 - m42) * afp(0.25) - (m41 - m43) * afp(0.5);
v53 = (m54 - m52) * afp(0.25) - (m51 - m53) * afp(0.5);

v04 = (m04 - m02) * afp(0.25) + (m01 - m03) * afp(0.5);
v14 = (m14 - m12) * afp(0.25) + (m11 - m13) * afp(0.5);
v24 = (m24 - m22) * afp(0.25) + (m21 - m23) * afp(0.5);
v34 = (m34 - m32) * afp(0.25) + (m31 - m33) * afp(0.5);
v44 = (m44 - m42) * afp(0.25) + (m41 - m43) * afp(0.5);
v54 = (m54 - m52) * afp(0.25) + (m51 - m53) * afp(0.5);

v05 = m01 - m03 * afp(1.25) + m05 * afp(0.25);
v15 = m11 - m13 * afp(1.25) + m15 * afp(0.25);
v25 = m21 - m23 * afp(1.25) + m25 * afp(0.25);
v35 = m31 - m33 * afp(1.25) + m35 * afp(0.25);
v45 = m41 - m43 * afp(1.25) + m45 * afp(0.25);
v55 = m51 - m53 * afp(1.25) + m55 * afp(0.25);
afpvec8 m00 = v00 - v02 * afp(2.5) + v04;
afpvec8 m01 = v10 - v12 * afp(2.5) + v14;
afpvec8 m02 = v20 - v22 * afp(2.5) + v24;
afpvec8 m03 = v30 - v32 * afp(2.5) + v34;
afpvec8 m04 = v40 - v42 * afp(2.5) + v44;
afpvec8 m05 = v50 - v52 * afp(2.5) + v54;

afpvec8 s0 = v03 * afp(sq2_d2) - v01 * afp(sq2);
afpvec8 s1 = v13 * afp(sq2_d2) - v11 * afp(sq2);
afpvec8 s2 = v23 * afp(sq2_d2) - v21 * afp(sq2);
afpvec8 s3 = v33 * afp(sq2_d2) - v31 * afp(sq2);
afpvec8 s4 = v43 * afp(sq2_d2) - v41 * afp(sq2);
afpvec8 s5 = v53 * afp(sq2_d2) - v51 * afp(sq2);

afpvec8 t0 = v04 - v02 * afp(2);
afpvec8 t1 = v14 - v12 * afp(2);
afpvec8 t2 = v24 - v22 * afp(2);
afpvec8 t3 = v34 - v32 * afp(2);
afpvec8 t4 = v44 - v42 * afp(2);
afpvec8 t5 = v54 - v52 * afp(2);

afpvec8 m10 = t0 + s0;
afpvec8 m11 = t1 + s1;
afpvec8 m12 = t2 + s2;
afpvec8 m13 = t3 + s3;
afpvec8 m14 = t4 + s4;
afpvec8 m15 = t5 + s5;

afpvec8 m20 = t0 - s0;
afpvec8 m21 = t1 - s1;
afpvec8 m22 = t2 - s2;
afpvec8 m23 = t3 - s3;
afpvec8 m24 = t4 - s4;
afpvec8 m25 = t5 - s5;

s0 = v03 * afp(sq2) - v01 * afp(sq2_d2);
s1 = v13 * afp(sq2) - v11 * afp(sq2_d2);
s2 = v23 * afp(sq2) - v21 * afp(sq2_d2);
s3 = v33 * afp(sq2) - v31 * afp(sq2_d2);
s4 = v43 * afp(sq2) - v41 * afp(sq2_d2);
s5 = v53 * afp(sq2) - v51 * afp(sq2_d2);

t0 = v04 - v02 * afp(0.5);
t1 = v14 - v12 * afp(0.5);
t2 = v24 - v22 * afp(0.5);
t3 = v34 - v32 * afp(0.5);
t4 = v44 - v42 * afp(0.5);
t5 = v54 - v52 * afp(0.5);

afpvec8 m30 = t0 + s0;
afpvec8 m31 = t1 + s1;
afpvec8 m32 = t2 + s2;
afpvec8 m33 = t3 + s3;
afpvec8 m34 = t4 + s4;
afpvec8 m35 = t5 + s5;

afpvec8 m40 = t0 - s0;
afpvec8 m41 = t1 - s1;
afpvec8 m42 = t2 - s2;
afpvec8 m43 = t3 - s3;
afpvec8 m44 = t4 - s4;
afpvec8 m45 = t5 - s5;

afpvec8 m50 = v01 - v03 * afp(2.5) + v05;
afpvec8 m51 = v11 - v13 * afp(2.5) + v15;
afpvec8 m52 = v21 - v23 * afp(2.5) + v25;
afpvec8 m53 = v31 - v33 * afp(2.5) + v35;
afpvec8 m54 = v41 - v43 * afp(2.5) + v45;
afpvec8 m55 = v51 - v53 * afp(2.5) + v55;

v00 = m00 - m02 * afp(2.5) + m04;
v10 = m10 - m12 * afp(2.5) + m14;
v20 = m20 - m22 * afp(2.5) + m24;
v30 = m30 - m32 * afp(2.5) + m34;
v40 = m40 - m42 * afp(2.5) + m44;
v50 = m50 - m52 * afp(2.5) + m54;

s0 = m03 * afp(sq2_d2) - m01 * afp(sq2);
s1 = m13 * afp(sq2_d2) - m11 * afp(sq2);
s2 = m23 * afp(sq2_d2) - m21 * afp(sq2);
s3 = m33 * afp(sq2_d2) - m31 * afp(sq2);
s4 = m43 * afp(sq2_d2) - m41 * afp(sq2);
s5 = m53 * afp(sq2_d2) - m51 * afp(sq2);

t0 = m04 - m02 * afp(2);
t1 = m14 - m12 * afp(2);
t2 = m24 - m22 * afp(2);
t3 = m34 - m32 * afp(2);
t4 = m44 - m42 * afp(2);
t5 = m54 - m52 * afp(2);

v01 = t0 + s0;
v11 = t1 + s1;
v21 = t2 + s2;
v31 = t3 + s3;
v41 = t4 + s4;
v51 = t5 + s5;

v02 = t0 - s0;
v12 = t1 - s1;
v22 = t2 - s2;
v32 = t3 - s3;
v42 = t4 - s4;
v52 = t5 - s5;

s0 = m03 * afp(sq2) - m01 * afp(sq2_d2);
s1 = m13 * afp(sq2) - m11 * afp(sq2_d2);
s2 = m23 * afp(sq2) - m21 * afp(sq2_d2);
s3 = m33 * afp(sq2) - m31 * afp(sq2_d2);
s4 = m43 * afp(sq2) - m41 * afp(sq2_d2);
s5 = m53 * afp(sq2) - m51 * afp(sq2_d2);

t0 = m04 - m02 * afp(0.5);
t1 = m14 - m12 * afp(0.5);
t2 = m24 - m22 * afp(0.5);
t3 = m34 - m32 * afp(0.5);
t4 = m44 - m42 * afp(0.5);
t5 = m54 - m52 * afp(0.5);

v03 = t0 + s0;
v13 = t1 + s1;
v23 = t2 + s2;
v33 = t3 + s3;
v43 = t4 + s4;
v53 = t5 + s5;

v04 = t0 - s0;
v14 = t1 - s1;
v24 = t2 - s2;
v34 = t3 - s3;
v44 = t4 - s4;
v54 = t5 - s5;

v05 = m01 - m03 * afp(2.5) + m05;
v15 = m11 - m13 * afp(2.5) + m15;
v25 = m21 - m23 * afp(2.5) + m25;
v35 = m31 - m33 * afp(2.5) + m35;
v45 = m41 - m43 * afp(2.5) + m45;
v55 = m51 - m53 * afp(2.5) + m55;

// store 36
#if NCNN_image_shader


+ 102
- 56
src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_output.comp View File

@@ -154,66 +154,112 @@ void main()
afpvec8 v55 = buffer_ld8(top_tm_blob_data, v_tm_offset + 35 * psc(cstep));
#endif

// sqrt(2) and sqrt(2)/4, the scale factors used by the output transform below.
// The derived constant is parenthesized so it expands as a single value even
// when the macro is used inside a larger expression (e.g. x / sq2_d4).
#define sq2 1.41421356237
#define sq2_d4 (1.41421356237 / 4)

// const float otm[4][6] = {
// {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
// {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
// {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f},
// {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
// {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
// {0.0f, sq2/2, -sq2/2, sq2, -sq2, 0.0f},
// {0.0f, 0.5f, 0.5f, 2.0f, 2.0f, 0.0f},
// {0.0f, sq2/4, -sq2/4, sq2*2, -sq2*2, 1.0f}
// };

// 0 = r00 + (r01 + r02) + (r03 + r04)
// 1 = (r01 - r02) + (r03 - r04) * 2
// 2 = (r01 + r02) + (r03 + r04) * 4
// 3 = r05 + (r01 - r02) + (r03 - r04) * 8

// implicit transpose
afpvec8 m00 = v00 + (v01 + v02) + (v03 + v04);
afpvec8 m01 = v10 + (v11 + v12) + (v13 + v14);
afpvec8 m02 = v20 + (v21 + v22) + (v23 + v24);
afpvec8 m03 = v30 + (v31 + v32) + (v33 + v34);
afpvec8 m04 = v40 + (v41 + v42) + (v43 + v44);
afpvec8 m05 = v50 + (v51 + v52) + (v53 + v54);

afpvec8 m10 = (v01 - v02) + (v03 - v04) * afp(2);
afpvec8 m11 = (v11 - v12) + (v13 - v14) * afp(2);
afpvec8 m12 = (v21 - v22) + (v23 - v24) * afp(2);
afpvec8 m13 = (v31 - v32) + (v33 - v34) * afp(2);
afpvec8 m14 = (v41 - v42) + (v43 - v44) * afp(2);
afpvec8 m15 = (v51 - v52) + (v53 - v54) * afp(2);

afpvec8 m20 = (v01 + v02) + (v03 + v04) * afp(4);
afpvec8 m21 = (v11 + v12) + (v13 + v14) * afp(4);
afpvec8 m22 = (v21 + v22) + (v23 + v24) * afp(4);
afpvec8 m23 = (v31 + v32) + (v33 + v34) * afp(4);
afpvec8 m24 = (v41 + v42) + (v43 + v44) * afp(4);
afpvec8 m25 = (v51 + v52) + (v53 + v54) * afp(4);

afpvec8 m30 = v05 + (v01 - v02) + (v03 - v04) * afp(8);
afpvec8 m31 = v15 + (v11 - v12) + (v13 - v14) * afp(8);
afpvec8 m32 = v25 + (v21 - v22) + (v23 - v24) * afp(8);
afpvec8 m33 = v35 + (v31 - v32) + (v33 - v34) * afp(8);
afpvec8 m34 = v45 + (v41 - v42) + (v43 - v44) * afp(8);
afpvec8 m35 = v55 + (v51 - v52) + (v53 - v54) * afp(8);

v00 = m00 + (m01 + m02) + (m03 + m04);
v10 = m10 + (m11 + m12) + (m13 + m14);
v20 = m20 + (m21 + m22) + (m23 + m24);
v30 = m30 + (m31 + m32) + (m33 + m34);

v01 = (m01 - m02) + (m03 - m04) * afp(2);
v11 = (m11 - m12) + (m13 - m14) * afp(2);
v21 = (m21 - m22) + (m23 - m24) * afp(2);
v31 = (m31 - m32) + (m33 - m34) * afp(2);

v02 = (m01 + m02) + (m03 + m04) * afp(4);
v12 = (m11 + m12) + (m13 + m14) * afp(4);
v22 = (m21 + m22) + (m23 + m24) * afp(4);
v32 = (m31 + m32) + (m33 + m34) * afp(4);

v03 = m05 + (m01 - m02) + (m03 - m04) * afp(8);
v13 = m15 + (m11 - m12) + (m13 - m14) * afp(8);
v23 = m25 + (m21 - m22) + (m23 - m24) * afp(8);
v33 = m35 + (m31 - m32) + (m33 - m34) * afp(8);
afpvec8 s0 = (v01 + v02) * afp(0.5);
afpvec8 s1 = (v11 + v12) * afp(0.5);
afpvec8 s2 = (v21 + v22) * afp(0.5);
afpvec8 s3 = (v31 + v32) * afp(0.5);
afpvec8 s4 = (v41 + v42) * afp(0.5);
afpvec8 s5 = (v51 + v52) * afp(0.5);

afpvec8 t0 = (v01 - v02) * afp(sq2_d4);
afpvec8 t1 = (v11 - v12) * afp(sq2_d4);
afpvec8 t2 = (v21 - v22) * afp(sq2_d4);
afpvec8 t3 = (v31 - v32) * afp(sq2_d4);
afpvec8 t4 = (v41 - v42) * afp(sq2_d4);
afpvec8 t5 = (v51 - v52) * afp(sq2_d4);

afpvec8 u0 = v03 + v04;
afpvec8 u1 = v13 + v14;
afpvec8 u2 = v23 + v24;
afpvec8 u3 = v33 + v34;
afpvec8 u4 = v43 + v44;
afpvec8 u5 = v53 + v54;

afpvec8 v0 = (v03 - v04) * afp(sq2);
afpvec8 v1 = (v13 - v14) * afp(sq2);
afpvec8 v2 = (v23 - v24) * afp(sq2);
afpvec8 v3 = (v33 - v34) * afp(sq2);
afpvec8 v4 = (v43 - v44) * afp(sq2);
afpvec8 v5 = (v53 - v54) * afp(sq2);

afpvec8 m00 = v00 + s0 + s0 + u0;
afpvec8 m01 = v10 + s1 + s1 + u1;
afpvec8 m02 = v20 + s2 + s2 + u2;
afpvec8 m03 = v30 + s3 + s3 + u3;
afpvec8 m04 = v40 + s4 + s4 + u4;
afpvec8 m05 = v50 + s5 + s5 + u5;

afpvec8 m10 = t0 + t0 + v0;
afpvec8 m11 = t1 + t1 + v1;
afpvec8 m12 = t2 + t2 + v2;
afpvec8 m13 = t3 + t3 + v3;
afpvec8 m14 = t4 + t4 + v4;
afpvec8 m15 = t5 + t5 + v5;

afpvec8 m20 = s0 + u0 + u0;
afpvec8 m21 = s1 + u1 + u1;
afpvec8 m22 = s2 + u2 + u2;
afpvec8 m23 = s3 + u3 + u3;
afpvec8 m24 = s4 + u4 + u4;
afpvec8 m25 = s5 + u5 + u5;

afpvec8 m30 = v05 + t0 + v0 + v0;
afpvec8 m31 = v15 + t1 + v1 + v1;
afpvec8 m32 = v25 + t2 + v2 + v2;
afpvec8 m33 = v35 + t3 + v3 + v3;
afpvec8 m34 = v45 + t4 + v4 + v4;
afpvec8 m35 = v55 + t5 + v5 + v5;

s0 = (m01 + m02) * afp(0.5);
s1 = (m11 + m12) * afp(0.5);
s2 = (m21 + m22) * afp(0.5);
s3 = (m31 + m32) * afp(0.5);

t0 = (m01 - m02) * afp(sq2_d4);
t1 = (m11 - m12) * afp(sq2_d4);
t2 = (m21 - m22) * afp(sq2_d4);
t3 = (m31 - m32) * afp(sq2_d4);

u0 = m03 + m04;
u1 = m13 + m14;
u2 = m23 + m24;
u3 = m33 + m34;

v0 = (m03 - m04) * afp(sq2);
v1 = (m13 - m14) * afp(sq2);
v2 = (m23 - m24) * afp(sq2);
v3 = (m33 - m34) * afp(sq2);

v00 = m00 + s0 + s0 + u0;
v10 = m10 + s1 + s1 + u1;
v20 = m20 + s2 + s2 + u2;
v30 = m30 + s3 + s3 + u3;

v01 = t0 + t0 + v0;
v11 = t1 + t1 + v1;
v21 = t2 + t2 + v2;
v31 = t3 + t3 + v3;

v02 = s0 + u0 + u0;
v12 = s1 + u1 + u1;
v22 = s2 + u2 + u2;
v32 = s3 + u3 + u3;

v03 = m05 + t0 + v0 + v0;
v13 = m15 + t1 + v1 + v1;
v23 = m25 + t2 + v2 + v2;
v33 = m35 + t3 + v3 + v3;

if (bias_term == 1)
{


+ 0
- 10
tests/test_convolution.cpp View File

@@ -41,16 +41,6 @@ static int test_convolution(int w, int h, int c, int outch, int kernel, int dila
weights[1] = RandomMat(outch);

float epsilon = 0.001;
// larger epsilon for winograd optimization
if (kernel == 3 && dilation == 1 && stride == 1 && c >= 16 && outch >= 16)
{
Randomize(a, -1, 1);
if (c >= 64 || outch >= 64)
Randomize(weights[0], -0.3, 0.3);
else
Randomize(weights[0], -1, 1);
epsilon = 0.002;
}

int ret = test_layer<ncnn::Convolution>("Convolution", pd, weights, a, epsilon);
if (ret != 0)


+ 0
- 10
tests/test_convolution_1.cpp View File

@@ -41,16 +41,6 @@ static int test_convolution(int w, int h, int c, int outch, int kernel, int dila
weights[1] = RandomMat(outch);

float epsilon = 0.001;
// larger epsilon for winograd optimization
if (kernel == 3 && dilation == 1 && stride == 1 && c >= 16 && outch >= 16)
{
Randomize(a, -1, 1);
if (c >= 64 || outch >= 64)
Randomize(weights[0], -0.3, 0.3);
else
Randomize(weights[0], -1, 1);
epsilon = 0.002;
}

int ret = test_layer<ncnn::Convolution>("Convolution", pd, weights, a, epsilon);
if (ret != 0)


+ 0
- 10
tests/test_convolution_2.cpp View File

@@ -41,16 +41,6 @@ static int test_convolution(int w, int h, int c, int outch, int kernel, int dila
weights[1] = RandomMat(outch);

float epsilon = 0.001;
// larger epsilon for winograd optimization
if (kernel == 3 && dilation == 1 && stride == 1 && c >= 16 && outch >= 16)
{
Randomize(a, -1, 1);
if (c >= 64 || outch >= 64)
Randomize(weights[0], -0.3, 0.3);
else
Randomize(weights[0], -1, 1);
epsilon = 0.002;
}

int ret = test_layer<ncnn::Convolution>("Convolution", pd, weights, a, epsilon);
if (ret != 0)


Loading…
Cancel
Save