diff --git a/src/layer/vulkan/convolution_vulkan.cpp b/src/layer/vulkan/convolution_vulkan.cpp index 22e817d34..caab40f50 100644 --- a/src/layer/vulkan/convolution_vulkan.cpp +++ b/src/layer/vulkan/convolution_vulkan.cpp @@ -188,13 +188,14 @@ int Convolution_vulkan::create_pipeline(const Option& _opt) Mat weight_data_tm; weight_data_tm.create(6 * 6, num_input, num_output); + const float sq2 = 1.41421356237f; const float ktm[6][3] = { {1.0f, 0.0f, 0.0f}, - {-2.0f / 3, -2.0f / 3, -2.0f / 3}, - {-2.0f / 3, 2.0f / 3, -2.0f / 3}, - {1.0f / 6, 1.0f / 3, 2.0f / 3}, - {1.0f / 6, -1.0f / 3, 2.0f / 3}, - {0.0f, 0.0f, 4.0f} + {-2.0f / 3, -sq2 / 3, -1.0f / 3}, + {-2.0f / 3, sq2 / 3, -1.0f / 3}, + {1.0f / 6, sq2 / 6, 1.0f / 3}, + {1.0f / 6, -sq2 / 6, 1.0f / 3}, + {0.0f, 0.0f, 1.0f} }; #pragma omp parallel for num_threads(opt.num_threads) diff --git a/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_input.comp b/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_input.comp index b3026ca11..d989d65e8 100644 --- a/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_input.comp +++ b/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_input.comp @@ -156,106 +156,158 @@ void main() afp v55 = sy + 5 < psc(h) && sx + 5 < psc(w) ? 
buffer_ld1(bottom_blob_data, v_offset45.y + 5) : afp(0.f); #endif +#define sq2 1.41421356237 +#define sq2_d2 1.41421356237/2 + // const float itm[6][6] = { - // {1.0f, 0.0f, -1.25f, 0.0f, 0.25f, 0.0f}, - // {0.0f,-1.0f, -1.0f, 0.25f, 0.25f, 0.0f}, - // {0.0f, 1.0f, -1.0f,-0.25f, 0.25f, 0.0f}, - // {0.0f,-0.5f, -0.25f, 0.5f, 0.25f, 0.0f}, - // {0.0f, 0.5f, -0.25f,-0.5f, 0.25f, 0.0f}, - // {0.0f, 1.0f, 0.0f,-1.25f, 0.0f, 0.25f} + // {1.0f, 0.0f, -2.5f, 0.0f, 1.0f, 0.0f}, + // {0.0f, -sq2, -2.0f, sq2/2, 1.0f, 0.0f}, + // {0.0f, sq2, -2.0f, -sq2/2, 1.0f, 0.0f}, + // {0.0f, -sq2/2, -0.5f, sq2, 1.0f, 0.0f}, + // {0.0f, sq2/2, -0.5f, -sq2, 1.0f, 0.0f}, + // {0.0f, 1.0f, 0.0f, -2.5f, 0.0f, 1.0f} // }; - // 0 = 1 * r00 - 1.25 * r02 + 0.25 * r04 - // 1 = -1 * (r01 + r02) + 0.25 * (r04 + r03) - // 2 = 1 * (r01 - r02) + 0.25 * (r04 - r03) - // 3 = -0.5 * (r01 - r03) + 0.25 * (r04 - r02) - // 4 = 0.5 * (r01 - r03) + 0.25 * (r04 - r02) - // 5 = 1 * r01 - 1.25 * r03 + 0.25 * r05 - // implicit transpose - afp m00 = v00 - v02 * afp(1.25) + v04 * afp(0.25); - afp m01 = v10 - v12 * afp(1.25) + v14 * afp(0.25); - afp m02 = v20 - v22 * afp(1.25) + v24 * afp(0.25); - afp m03 = v30 - v32 * afp(1.25) + v34 * afp(0.25); - afp m04 = v40 - v42 * afp(1.25) + v44 * afp(0.25); - afp m05 = v50 - v52 * afp(1.25) + v54 * afp(0.25); - - afp m10 = (v04 + v03) * afp(0.25) - (v01 + v02); - afp m11 = (v14 + v13) * afp(0.25) - (v11 + v12); - afp m12 = (v24 + v23) * afp(0.25) - (v21 + v22); - afp m13 = (v34 + v33) * afp(0.25) - (v31 + v32); - afp m14 = (v44 + v43) * afp(0.25) - (v41 + v42); - afp m15 = (v54 + v53) * afp(0.25) - (v51 + v52); - - afp m20 = (v04 - v03) * afp(0.25) + (v01 - v02); - afp m21 = (v14 - v13) * afp(0.25) + (v11 - v12); - afp m22 = (v24 - v23) * afp(0.25) + (v21 - v22); - afp m23 = (v34 - v33) * afp(0.25) + (v31 - v32); - afp m24 = (v44 - v43) * afp(0.25) + (v41 - v42); - afp m25 = (v54 - v53) * afp(0.25) + (v51 - v52); - - afp m30 = (v04 - v02) * afp(0.25) - (v01 - v03) * afp(0.5); 
- afp m31 = (v14 - v12) * afp(0.25) - (v11 - v13) * afp(0.5); - afp m32 = (v24 - v22) * afp(0.25) - (v21 - v23) * afp(0.5); - afp m33 = (v34 - v32) * afp(0.25) - (v31 - v33) * afp(0.5); - afp m34 = (v44 - v42) * afp(0.25) - (v41 - v43) * afp(0.5); - afp m35 = (v54 - v52) * afp(0.25) - (v51 - v53) * afp(0.5); - - afp m40 = (v04 - v02) * afp(0.25) + (v01 - v03) * afp(0.5); - afp m41 = (v14 - v12) * afp(0.25) + (v11 - v13) * afp(0.5); - afp m42 = (v24 - v22) * afp(0.25) + (v21 - v23) * afp(0.5); - afp m43 = (v34 - v32) * afp(0.25) + (v31 - v33) * afp(0.5); - afp m44 = (v44 - v42) * afp(0.25) + (v41 - v43) * afp(0.5); - afp m45 = (v54 - v52) * afp(0.25) + (v51 - v53) * afp(0.5); - - afp m50 = v01 - v03 * afp(1.25) + v05 * afp(0.25); - afp m51 = v11 - v13 * afp(1.25) + v15 * afp(0.25); - afp m52 = v21 - v23 * afp(1.25) + v25 * afp(0.25); - afp m53 = v31 - v33 * afp(1.25) + v35 * afp(0.25); - afp m54 = v41 - v43 * afp(1.25) + v45 * afp(0.25); - afp m55 = v51 - v53 * afp(1.25) + v55 * afp(0.25); - - v00 = m00 - m02 * afp(1.25) + m04 * afp(0.25); - v10 = m10 - m12 * afp(1.25) + m14 * afp(0.25); - v20 = m20 - m22 * afp(1.25) + m24 * afp(0.25); - v30 = m30 - m32 * afp(1.25) + m34 * afp(0.25); - v40 = m40 - m42 * afp(1.25) + m44 * afp(0.25); - v50 = m50 - m52 * afp(1.25) + m54 * afp(0.25); - - v01 = (m04 + m03) * afp(0.25) - (m01 + m02); - v11 = (m14 + m13) * afp(0.25) - (m11 + m12); - v21 = (m24 + m23) * afp(0.25) - (m21 + m22); - v31 = (m34 + m33) * afp(0.25) - (m31 + m32); - v41 = (m44 + m43) * afp(0.25) - (m41 + m42); - v51 = (m54 + m53) * afp(0.25) - (m51 + m52); - - v02 = (m04 - m03) * afp(0.25) + (m01 - m02); - v12 = (m14 - m13) * afp(0.25) + (m11 - m12); - v22 = (m24 - m23) * afp(0.25) + (m21 - m22); - v32 = (m34 - m33) * afp(0.25) + (m31 - m32); - v42 = (m44 - m43) * afp(0.25) + (m41 - m42); - v52 = (m54 - m53) * afp(0.25) + (m51 - m52); - - v03 = (m04 - m02) * afp(0.25) - (m01 - m03) * afp(0.5); - v13 = (m14 - m12) * afp(0.25) - (m11 - m13) * afp(0.5); - v23 = (m24 
- m22) * afp(0.25) - (m21 - m23) * afp(0.5); - v33 = (m34 - m32) * afp(0.25) - (m31 - m33) * afp(0.5); - v43 = (m44 - m42) * afp(0.25) - (m41 - m43) * afp(0.5); - v53 = (m54 - m52) * afp(0.25) - (m51 - m53) * afp(0.5); - - v04 = (m04 - m02) * afp(0.25) + (m01 - m03) * afp(0.5); - v14 = (m14 - m12) * afp(0.25) + (m11 - m13) * afp(0.5); - v24 = (m24 - m22) * afp(0.25) + (m21 - m23) * afp(0.5); - v34 = (m34 - m32) * afp(0.25) + (m31 - m33) * afp(0.5); - v44 = (m44 - m42) * afp(0.25) + (m41 - m43) * afp(0.5); - v54 = (m54 - m52) * afp(0.25) + (m51 - m53) * afp(0.5); - - v05 = m01 - m03 * afp(1.25) + m05 * afp(0.25); - v15 = m11 - m13 * afp(1.25) + m15 * afp(0.25); - v25 = m21 - m23 * afp(1.25) + m25 * afp(0.25); - v35 = m31 - m33 * afp(1.25) + m35 * afp(0.25); - v45 = m41 - m43 * afp(1.25) + m45 * afp(0.25); - v55 = m51 - m53 * afp(1.25) + m55 * afp(0.25); + afp m00 = v00 - v02 * afp(2.5) + v04; + afp m01 = v10 - v12 * afp(2.5) + v14; + afp m02 = v20 - v22 * afp(2.5) + v24; + afp m03 = v30 - v32 * afp(2.5) + v34; + afp m04 = v40 - v42 * afp(2.5) + v44; + afp m05 = v50 - v52 * afp(2.5) + v54; + + afp s0 = v03 * afp(sq2_d2) - v01 * afp(sq2); + afp s1 = v13 * afp(sq2_d2) - v11 * afp(sq2); + afp s2 = v23 * afp(sq2_d2) - v21 * afp(sq2); + afp s3 = v33 * afp(sq2_d2) - v31 * afp(sq2); + afp s4 = v43 * afp(sq2_d2) - v41 * afp(sq2); + afp s5 = v53 * afp(sq2_d2) - v51 * afp(sq2); + + afp t0 = v04 - v02 * afp(2); + afp t1 = v14 - v12 * afp(2); + afp t2 = v24 - v22 * afp(2); + afp t3 = v34 - v32 * afp(2); + afp t4 = v44 - v42 * afp(2); + afp t5 = v54 - v52 * afp(2); + + afp m10 = t0 + s0; + afp m11 = t1 + s1; + afp m12 = t2 + s2; + afp m13 = t3 + s3; + afp m14 = t4 + s4; + afp m15 = t5 + s5; + + afp m20 = t0 - s0; + afp m21 = t1 - s1; + afp m22 = t2 - s2; + afp m23 = t3 - s3; + afp m24 = t4 - s4; + afp m25 = t5 - s5; + + s0 = v03 * afp(sq2) - v01 * afp(sq2_d2); + s1 = v13 * afp(sq2) - v11 * afp(sq2_d2); + s2 = v23 * afp(sq2) - v21 * afp(sq2_d2); + s3 = v33 * afp(sq2) - v31 * 
afp(sq2_d2); + s4 = v43 * afp(sq2) - v41 * afp(sq2_d2); + s5 = v53 * afp(sq2) - v51 * afp(sq2_d2); + + t0 = v04 - v02 * afp(0.5); + t1 = v14 - v12 * afp(0.5); + t2 = v24 - v22 * afp(0.5); + t3 = v34 - v32 * afp(0.5); + t4 = v44 - v42 * afp(0.5); + t5 = v54 - v52 * afp(0.5); + + afp m30 = t0 + s0; + afp m31 = t1 + s1; + afp m32 = t2 + s2; + afp m33 = t3 + s3; + afp m34 = t4 + s4; + afp m35 = t5 + s5; + + afp m40 = t0 - s0; + afp m41 = t1 - s1; + afp m42 = t2 - s2; + afp m43 = t3 - s3; + afp m44 = t4 - s4; + afp m45 = t5 - s5; + + afp m50 = v01 - v03 * afp(2.5) + v05; + afp m51 = v11 - v13 * afp(2.5) + v15; + afp m52 = v21 - v23 * afp(2.5) + v25; + afp m53 = v31 - v33 * afp(2.5) + v35; + afp m54 = v41 - v43 * afp(2.5) + v45; + afp m55 = v51 - v53 * afp(2.5) + v55; + + v00 = m00 - m02 * afp(2.5) + m04; + v10 = m10 - m12 * afp(2.5) + m14; + v20 = m20 - m22 * afp(2.5) + m24; + v30 = m30 - m32 * afp(2.5) + m34; + v40 = m40 - m42 * afp(2.5) + m44; + v50 = m50 - m52 * afp(2.5) + m54; + + s0 = m03 * afp(sq2_d2) - m01 * afp(sq2); + s1 = m13 * afp(sq2_d2) - m11 * afp(sq2); + s2 = m23 * afp(sq2_d2) - m21 * afp(sq2); + s3 = m33 * afp(sq2_d2) - m31 * afp(sq2); + s4 = m43 * afp(sq2_d2) - m41 * afp(sq2); + s5 = m53 * afp(sq2_d2) - m51 * afp(sq2); + + t0 = m04 - m02 * afp(2); + t1 = m14 - m12 * afp(2); + t2 = m24 - m22 * afp(2); + t3 = m34 - m32 * afp(2); + t4 = m44 - m42 * afp(2); + t5 = m54 - m52 * afp(2); + + v01 = t0 + s0; + v11 = t1 + s1; + v21 = t2 + s2; + v31 = t3 + s3; + v41 = t4 + s4; + v51 = t5 + s5; + + v02 = t0 - s0; + v12 = t1 - s1; + v22 = t2 - s2; + v32 = t3 - s3; + v42 = t4 - s4; + v52 = t5 - s5; + + s0 = m03 * afp(sq2) - m01 * afp(sq2_d2); + s1 = m13 * afp(sq2) - m11 * afp(sq2_d2); + s2 = m23 * afp(sq2) - m21 * afp(sq2_d2); + s3 = m33 * afp(sq2) - m31 * afp(sq2_d2); + s4 = m43 * afp(sq2) - m41 * afp(sq2_d2); + s5 = m53 * afp(sq2) - m51 * afp(sq2_d2); + + t0 = m04 - m02 * afp(0.5); + t1 = m14 - m12 * afp(0.5); + t2 = m24 - m22 * afp(0.5); + t3 = m34 - m32 * 
afp(0.5); + t4 = m44 - m42 * afp(0.5); + t5 = m54 - m52 * afp(0.5); + + v03 = t0 + s0; + v13 = t1 + s1; + v23 = t2 + s2; + v33 = t3 + s3; + v43 = t4 + s4; + v53 = t5 + s5; + + v04 = t0 - s0; + v14 = t1 - s1; + v24 = t2 - s2; + v34 = t3 - s3; + v44 = t4 - s4; + v54 = t5 - s5; + + v05 = m01 - m03 * afp(2.5) + m05; + v15 = m11 - m13 * afp(2.5) + m15; + v25 = m21 - m23 * afp(2.5) + m25; + v35 = m31 - m33 * afp(2.5) + m35; + v45 = m41 - m43 * afp(2.5) + m45; + v55 = m51 - m53 * afp(2.5) + m55; // store 36 #if NCNN_image_shader diff --git a/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_output.comp b/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_output.comp index 4da8e8f51..c70f596cd 100644 --- a/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_output.comp +++ b/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_output.comp @@ -153,66 +153,112 @@ void main() afp v55 = buffer_ld1(top_tm_blob_data, v_tm_offset + 35 * psc(cstep)); #endif +#define sq2 1.41421356237 +#define sq2_d4 1.41421356237/4 + // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, sq2/2, -sq2/2, sq2, -sq2, 0.0f}, + // {0.0f, 0.5f, 0.5f, 2.0f, 2.0f, 0.0f}, + // {0.0f, sq2/4, -sq2/4, sq2*2, -sq2*2, 1.0f} // }; - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - // implicit transpose - afp m00 = v00 + (v01 + v02) + (v03 + v04); - afp m01 = v10 + (v11 + v12) + (v13 + v14); - afp m02 = v20 + (v21 + v22) + (v23 + v24); - afp m03 = v30 + (v31 + v32) + (v33 + v34); - afp m04 = v40 + (v41 + v42) + (v43 + v44); - afp m05 = v50 + (v51 + v52) + (v53 + v54); - - afp m10 = (v01 - v02) + (v03 - v04) * afp(2); - afp m11 = (v11 - v12) + (v13 
- v14) * afp(2); - afp m12 = (v21 - v22) + (v23 - v24) * afp(2); - afp m13 = (v31 - v32) + (v33 - v34) * afp(2); - afp m14 = (v41 - v42) + (v43 - v44) * afp(2); - afp m15 = (v51 - v52) + (v53 - v54) * afp(2); - - afp m20 = (v01 + v02) + (v03 + v04) * afp(4); - afp m21 = (v11 + v12) + (v13 + v14) * afp(4); - afp m22 = (v21 + v22) + (v23 + v24) * afp(4); - afp m23 = (v31 + v32) + (v33 + v34) * afp(4); - afp m24 = (v41 + v42) + (v43 + v44) * afp(4); - afp m25 = (v51 + v52) + (v53 + v54) * afp(4); - - afp m30 = v05 + (v01 - v02) + (v03 - v04) * afp(8); - afp m31 = v15 + (v11 - v12) + (v13 - v14) * afp(8); - afp m32 = v25 + (v21 - v22) + (v23 - v24) * afp(8); - afp m33 = v35 + (v31 - v32) + (v33 - v34) * afp(8); - afp m34 = v45 + (v41 - v42) + (v43 - v44) * afp(8); - afp m35 = v55 + (v51 - v52) + (v53 - v54) * afp(8); - - v00 = m00 + (m01 + m02) + (m03 + m04); - v10 = m10 + (m11 + m12) + (m13 + m14); - v20 = m20 + (m21 + m22) + (m23 + m24); - v30 = m30 + (m31 + m32) + (m33 + m34); - - v01 = (m01 - m02) + (m03 - m04) * afp(2); - v11 = (m11 - m12) + (m13 - m14) * afp(2); - v21 = (m21 - m22) + (m23 - m24) * afp(2); - v31 = (m31 - m32) + (m33 - m34) * afp(2); - - v02 = (m01 + m02) + (m03 + m04) * afp(4); - v12 = (m11 + m12) + (m13 + m14) * afp(4); - v22 = (m21 + m22) + (m23 + m24) * afp(4); - v32 = (m31 + m32) + (m33 + m34) * afp(4); - - v03 = m05 + (m01 - m02) + (m03 - m04) * afp(8); - v13 = m15 + (m11 - m12) + (m13 - m14) * afp(8); - v23 = m25 + (m21 - m22) + (m23 - m24) * afp(8); - v33 = m35 + (m31 - m32) + (m33 - m34) * afp(8); + afp s0 = (v01 + v02) * afp(0.5); + afp s1 = (v11 + v12) * afp(0.5); + afp s2 = (v21 + v22) * afp(0.5); + afp s3 = (v31 + v32) * afp(0.5); + afp s4 = (v41 + v42) * afp(0.5); + afp s5 = (v51 + v52) * afp(0.5); + + afp t0 = (v01 - v02) * afp(sq2_d4); + afp t1 = (v11 - v12) * afp(sq2_d4); + afp t2 = (v21 - v22) * afp(sq2_d4); + afp t3 = (v31 - v32) * afp(sq2_d4); + afp t4 = (v41 - v42) * afp(sq2_d4); + afp t5 = (v51 - v52) * afp(sq2_d4); + + afp u0 
= v03 + v04; + afp u1 = v13 + v14; + afp u2 = v23 + v24; + afp u3 = v33 + v34; + afp u4 = v43 + v44; + afp u5 = v53 + v54; + + afp v0 = (v03 - v04) * afp(sq2); + afp v1 = (v13 - v14) * afp(sq2); + afp v2 = (v23 - v24) * afp(sq2); + afp v3 = (v33 - v34) * afp(sq2); + afp v4 = (v43 - v44) * afp(sq2); + afp v5 = (v53 - v54) * afp(sq2); + + afp m00 = v00 + s0 + s0 + u0; + afp m01 = v10 + s1 + s1 + u1; + afp m02 = v20 + s2 + s2 + u2; + afp m03 = v30 + s3 + s3 + u3; + afp m04 = v40 + s4 + s4 + u4; + afp m05 = v50 + s5 + s5 + u5; + + afp m10 = t0 + t0 + v0; + afp m11 = t1 + t1 + v1; + afp m12 = t2 + t2 + v2; + afp m13 = t3 + t3 + v3; + afp m14 = t4 + t4 + v4; + afp m15 = t5 + t5 + v5; + + afp m20 = s0 + u0 + u0; + afp m21 = s1 + u1 + u1; + afp m22 = s2 + u2 + u2; + afp m23 = s3 + u3 + u3; + afp m24 = s4 + u4 + u4; + afp m25 = s5 + u5 + u5; + + afp m30 = v05 + t0 + v0 + v0; + afp m31 = v15 + t1 + v1 + v1; + afp m32 = v25 + t2 + v2 + v2; + afp m33 = v35 + t3 + v3 + v3; + afp m34 = v45 + t4 + v4 + v4; + afp m35 = v55 + t5 + v5 + v5; + + s0 = (m01 + m02) * afp(0.5); + s1 = (m11 + m12) * afp(0.5); + s2 = (m21 + m22) * afp(0.5); + s3 = (m31 + m32) * afp(0.5); + + t0 = (m01 - m02) * afp(sq2_d4); + t1 = (m11 - m12) * afp(sq2_d4); + t2 = (m21 - m22) * afp(sq2_d4); + t3 = (m31 - m32) * afp(sq2_d4); + + u0 = m03 + m04; + u1 = m13 + m14; + u2 = m23 + m24; + u3 = m33 + m34; + + v0 = (m03 - m04) * afp(sq2); + v1 = (m13 - m14) * afp(sq2); + v2 = (m23 - m24) * afp(sq2); + v3 = (m33 - m34) * afp(sq2); + + v00 = m00 + s0 + s0 + u0; + v10 = m10 + s1 + s1 + u1; + v20 = m20 + s2 + s2 + u2; + v30 = m30 + s3 + s3 + u3; + + v01 = t0 + t0 + v0; + v11 = t1 + t1 + v1; + v21 = t2 + t2 + v2; + v31 = t3 + t3 + v3; + + v02 = s0 + u0 + u0; + v12 = s1 + u1 + u1; + v22 = s2 + u2 + u2; + v32 = s3 + u3 + u3; + + v03 = m05 + t0 + v0 + v0; + v13 = m15 + t1 + v1 + v1; + v23 = m25 + t2 + v2 + v2; + v33 = m35 + t3 + v3 + v3; if (bias_term == 1) { diff --git 
a/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_input.comp b/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_input.comp index fef0c1daf..7d6ebe41b 100644 --- a/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_input.comp +++ b/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_input.comp @@ -156,106 +156,158 @@ void main() afpvec4 v55 = sy + 5 < psc(h) && sx + 5 < psc(w) ? buffer_ld4(bottom_blob_data, v_offset45.y + 5) : afpvec4(0.f); #endif +#define sq2 1.41421356237 +#define sq2_d2 1.41421356237/2 + // const float itm[6][6] = { - // {1.0f, 0.0f, -1.25f, 0.0f, 0.25f, 0.0f}, - // {0.0f,-1.0f, -1.0f, 0.25f, 0.25f, 0.0f}, - // {0.0f, 1.0f, -1.0f,-0.25f, 0.25f, 0.0f}, - // {0.0f,-0.5f, -0.25f, 0.5f, 0.25f, 0.0f}, - // {0.0f, 0.5f, -0.25f,-0.5f, 0.25f, 0.0f}, - // {0.0f, 1.0f, 0.0f,-1.25f, 0.0f, 0.25f} + // {1.0f, 0.0f, -2.5f, 0.0f, 1.0f, 0.0f}, + // {0.0f, -sq2, -2.0f, sq2/2, 1.0f, 0.0f}, + // {0.0f, sq2, -2.0f, -sq2/2, 1.0f, 0.0f}, + // {0.0f, -sq2/2, -0.5f, sq2, 1.0f, 0.0f}, + // {0.0f, sq2/2, -0.5f, -sq2, 1.0f, 0.0f}, + // {0.0f, 1.0f, 0.0f, -2.5f, 0.0f, 1.0f} // }; - // 0 = 1 * r00 - 1.25 * r02 + 0.25 * r04 - // 1 = -1 * (r01 + r02) + 0.25 * (r04 + r03) - // 2 = 1 * (r01 - r02) + 0.25 * (r04 - r03) - // 3 = -0.5 * (r01 - r03) + 0.25 * (r04 - r02) - // 4 = 0.5 * (r01 - r03) + 0.25 * (r04 - r02) - // 5 = 1 * r01 - 1.25 * r03 + 0.25 * r05 - // implicit transpose - afpvec4 m00 = v00 - v02 * afp(1.25) + v04 * afp(0.25); - afpvec4 m01 = v10 - v12 * afp(1.25) + v14 * afp(0.25); - afpvec4 m02 = v20 - v22 * afp(1.25) + v24 * afp(0.25); - afpvec4 m03 = v30 - v32 * afp(1.25) + v34 * afp(0.25); - afpvec4 m04 = v40 - v42 * afp(1.25) + v44 * afp(0.25); - afpvec4 m05 = v50 - v52 * afp(1.25) + v54 * afp(0.25); - - afpvec4 m10 = (v04 + v03) * afp(0.25) - (v01 + v02); - afpvec4 m11 = (v14 + v13) * afp(0.25) - (v11 + v12); - afpvec4 m12 = (v24 + v23) * afp(0.25) - (v21 + v22); - afpvec4 m13 = 
(v34 + v33) * afp(0.25) - (v31 + v32); - afpvec4 m14 = (v44 + v43) * afp(0.25) - (v41 + v42); - afpvec4 m15 = (v54 + v53) * afp(0.25) - (v51 + v52); - - afpvec4 m20 = (v04 - v03) * afp(0.25) + (v01 - v02); - afpvec4 m21 = (v14 - v13) * afp(0.25) + (v11 - v12); - afpvec4 m22 = (v24 - v23) * afp(0.25) + (v21 - v22); - afpvec4 m23 = (v34 - v33) * afp(0.25) + (v31 - v32); - afpvec4 m24 = (v44 - v43) * afp(0.25) + (v41 - v42); - afpvec4 m25 = (v54 - v53) * afp(0.25) + (v51 - v52); - - afpvec4 m30 = (v04 - v02) * afp(0.25) - (v01 - v03) * afp(0.5); - afpvec4 m31 = (v14 - v12) * afp(0.25) - (v11 - v13) * afp(0.5); - afpvec4 m32 = (v24 - v22) * afp(0.25) - (v21 - v23) * afp(0.5); - afpvec4 m33 = (v34 - v32) * afp(0.25) - (v31 - v33) * afp(0.5); - afpvec4 m34 = (v44 - v42) * afp(0.25) - (v41 - v43) * afp(0.5); - afpvec4 m35 = (v54 - v52) * afp(0.25) - (v51 - v53) * afp(0.5); - - afpvec4 m40 = (v04 - v02) * afp(0.25) + (v01 - v03) * afp(0.5); - afpvec4 m41 = (v14 - v12) * afp(0.25) + (v11 - v13) * afp(0.5); - afpvec4 m42 = (v24 - v22) * afp(0.25) + (v21 - v23) * afp(0.5); - afpvec4 m43 = (v34 - v32) * afp(0.25) + (v31 - v33) * afp(0.5); - afpvec4 m44 = (v44 - v42) * afp(0.25) + (v41 - v43) * afp(0.5); - afpvec4 m45 = (v54 - v52) * afp(0.25) + (v51 - v53) * afp(0.5); - - afpvec4 m50 = v01 - v03 * afp(1.25) + v05 * afp(0.25); - afpvec4 m51 = v11 - v13 * afp(1.25) + v15 * afp(0.25); - afpvec4 m52 = v21 - v23 * afp(1.25) + v25 * afp(0.25); - afpvec4 m53 = v31 - v33 * afp(1.25) + v35 * afp(0.25); - afpvec4 m54 = v41 - v43 * afp(1.25) + v45 * afp(0.25); - afpvec4 m55 = v51 - v53 * afp(1.25) + v55 * afp(0.25); - - v00 = m00 - m02 * afp(1.25) + m04 * afp(0.25); - v10 = m10 - m12 * afp(1.25) + m14 * afp(0.25); - v20 = m20 - m22 * afp(1.25) + m24 * afp(0.25); - v30 = m30 - m32 * afp(1.25) + m34 * afp(0.25); - v40 = m40 - m42 * afp(1.25) + m44 * afp(0.25); - v50 = m50 - m52 * afp(1.25) + m54 * afp(0.25); - - v01 = (m04 + m03) * afp(0.25) - (m01 + m02); - v11 = (m14 + m13) * afp(0.25) - 
(m11 + m12); - v21 = (m24 + m23) * afp(0.25) - (m21 + m22); - v31 = (m34 + m33) * afp(0.25) - (m31 + m32); - v41 = (m44 + m43) * afp(0.25) - (m41 + m42); - v51 = (m54 + m53) * afp(0.25) - (m51 + m52); - - v02 = (m04 - m03) * afp(0.25) + (m01 - m02); - v12 = (m14 - m13) * afp(0.25) + (m11 - m12); - v22 = (m24 - m23) * afp(0.25) + (m21 - m22); - v32 = (m34 - m33) * afp(0.25) + (m31 - m32); - v42 = (m44 - m43) * afp(0.25) + (m41 - m42); - v52 = (m54 - m53) * afp(0.25) + (m51 - m52); - - v03 = (m04 - m02) * afp(0.25) - (m01 - m03) * afp(0.5); - v13 = (m14 - m12) * afp(0.25) - (m11 - m13) * afp(0.5); - v23 = (m24 - m22) * afp(0.25) - (m21 - m23) * afp(0.5); - v33 = (m34 - m32) * afp(0.25) - (m31 - m33) * afp(0.5); - v43 = (m44 - m42) * afp(0.25) - (m41 - m43) * afp(0.5); - v53 = (m54 - m52) * afp(0.25) - (m51 - m53) * afp(0.5); - - v04 = (m04 - m02) * afp(0.25) + (m01 - m03) * afp(0.5); - v14 = (m14 - m12) * afp(0.25) + (m11 - m13) * afp(0.5); - v24 = (m24 - m22) * afp(0.25) + (m21 - m23) * afp(0.5); - v34 = (m34 - m32) * afp(0.25) + (m31 - m33) * afp(0.5); - v44 = (m44 - m42) * afp(0.25) + (m41 - m43) * afp(0.5); - v54 = (m54 - m52) * afp(0.25) + (m51 - m53) * afp(0.5); - - v05 = m01 - m03 * afp(1.25) + m05 * afp(0.25); - v15 = m11 - m13 * afp(1.25) + m15 * afp(0.25); - v25 = m21 - m23 * afp(1.25) + m25 * afp(0.25); - v35 = m31 - m33 * afp(1.25) + m35 * afp(0.25); - v45 = m41 - m43 * afp(1.25) + m45 * afp(0.25); - v55 = m51 - m53 * afp(1.25) + m55 * afp(0.25); + afpvec4 m00 = v00 - v02 * afp(2.5) + v04; + afpvec4 m01 = v10 - v12 * afp(2.5) + v14; + afpvec4 m02 = v20 - v22 * afp(2.5) + v24; + afpvec4 m03 = v30 - v32 * afp(2.5) + v34; + afpvec4 m04 = v40 - v42 * afp(2.5) + v44; + afpvec4 m05 = v50 - v52 * afp(2.5) + v54; + + afpvec4 s0 = v03 * afp(sq2_d2) - v01 * afp(sq2); + afpvec4 s1 = v13 * afp(sq2_d2) - v11 * afp(sq2); + afpvec4 s2 = v23 * afp(sq2_d2) - v21 * afp(sq2); + afpvec4 s3 = v33 * afp(sq2_d2) - v31 * afp(sq2); + afpvec4 s4 = v43 * afp(sq2_d2) - v41 * 
afp(sq2); + afpvec4 s5 = v53 * afp(sq2_d2) - v51 * afp(sq2); + + afpvec4 t0 = v04 - v02 * afp(2); + afpvec4 t1 = v14 - v12 * afp(2); + afpvec4 t2 = v24 - v22 * afp(2); + afpvec4 t3 = v34 - v32 * afp(2); + afpvec4 t4 = v44 - v42 * afp(2); + afpvec4 t5 = v54 - v52 * afp(2); + + afpvec4 m10 = t0 + s0; + afpvec4 m11 = t1 + s1; + afpvec4 m12 = t2 + s2; + afpvec4 m13 = t3 + s3; + afpvec4 m14 = t4 + s4; + afpvec4 m15 = t5 + s5; + + afpvec4 m20 = t0 - s0; + afpvec4 m21 = t1 - s1; + afpvec4 m22 = t2 - s2; + afpvec4 m23 = t3 - s3; + afpvec4 m24 = t4 - s4; + afpvec4 m25 = t5 - s5; + + s0 = v03 * afp(sq2) - v01 * afp(sq2_d2); + s1 = v13 * afp(sq2) - v11 * afp(sq2_d2); + s2 = v23 * afp(sq2) - v21 * afp(sq2_d2); + s3 = v33 * afp(sq2) - v31 * afp(sq2_d2); + s4 = v43 * afp(sq2) - v41 * afp(sq2_d2); + s5 = v53 * afp(sq2) - v51 * afp(sq2_d2); + + t0 = v04 - v02 * afp(0.5); + t1 = v14 - v12 * afp(0.5); + t2 = v24 - v22 * afp(0.5); + t3 = v34 - v32 * afp(0.5); + t4 = v44 - v42 * afp(0.5); + t5 = v54 - v52 * afp(0.5); + + afpvec4 m30 = t0 + s0; + afpvec4 m31 = t1 + s1; + afpvec4 m32 = t2 + s2; + afpvec4 m33 = t3 + s3; + afpvec4 m34 = t4 + s4; + afpvec4 m35 = t5 + s5; + + afpvec4 m40 = t0 - s0; + afpvec4 m41 = t1 - s1; + afpvec4 m42 = t2 - s2; + afpvec4 m43 = t3 - s3; + afpvec4 m44 = t4 - s4; + afpvec4 m45 = t5 - s5; + + afpvec4 m50 = v01 - v03 * afp(2.5) + v05; + afpvec4 m51 = v11 - v13 * afp(2.5) + v15; + afpvec4 m52 = v21 - v23 * afp(2.5) + v25; + afpvec4 m53 = v31 - v33 * afp(2.5) + v35; + afpvec4 m54 = v41 - v43 * afp(2.5) + v45; + afpvec4 m55 = v51 - v53 * afp(2.5) + v55; + + v00 = m00 - m02 * afp(2.5) + m04; + v10 = m10 - m12 * afp(2.5) + m14; + v20 = m20 - m22 * afp(2.5) + m24; + v30 = m30 - m32 * afp(2.5) + m34; + v40 = m40 - m42 * afp(2.5) + m44; + v50 = m50 - m52 * afp(2.5) + m54; + + s0 = m03 * afp(sq2_d2) - m01 * afp(sq2); + s1 = m13 * afp(sq2_d2) - m11 * afp(sq2); + s2 = m23 * afp(sq2_d2) - m21 * afp(sq2); + s3 = m33 * afp(sq2_d2) - m31 * afp(sq2); + s4 = m43 * afp(sq2_d2) 
- m41 * afp(sq2); + s5 = m53 * afp(sq2_d2) - m51 * afp(sq2); + + t0 = m04 - m02 * afp(2); + t1 = m14 - m12 * afp(2); + t2 = m24 - m22 * afp(2); + t3 = m34 - m32 * afp(2); + t4 = m44 - m42 * afp(2); + t5 = m54 - m52 * afp(2); + + v01 = t0 + s0; + v11 = t1 + s1; + v21 = t2 + s2; + v31 = t3 + s3; + v41 = t4 + s4; + v51 = t5 + s5; + + v02 = t0 - s0; + v12 = t1 - s1; + v22 = t2 - s2; + v32 = t3 - s3; + v42 = t4 - s4; + v52 = t5 - s5; + + s0 = m03 * afp(sq2) - m01 * afp(sq2_d2); + s1 = m13 * afp(sq2) - m11 * afp(sq2_d2); + s2 = m23 * afp(sq2) - m21 * afp(sq2_d2); + s3 = m33 * afp(sq2) - m31 * afp(sq2_d2); + s4 = m43 * afp(sq2) - m41 * afp(sq2_d2); + s5 = m53 * afp(sq2) - m51 * afp(sq2_d2); + + t0 = m04 - m02 * afp(0.5); + t1 = m14 - m12 * afp(0.5); + t2 = m24 - m22 * afp(0.5); + t3 = m34 - m32 * afp(0.5); + t4 = m44 - m42 * afp(0.5); + t5 = m54 - m52 * afp(0.5); + + v03 = t0 + s0; + v13 = t1 + s1; + v23 = t2 + s2; + v33 = t3 + s3; + v43 = t4 + s4; + v53 = t5 + s5; + + v04 = t0 - s0; + v14 = t1 - s1; + v24 = t2 - s2; + v34 = t3 - s3; + v44 = t4 - s4; + v54 = t5 - s5; + + v05 = m01 - m03 * afp(2.5) + m05; + v15 = m11 - m13 * afp(2.5) + m15; + v25 = m21 - m23 * afp(2.5) + m25; + v35 = m31 - m33 * afp(2.5) + m35; + v45 = m41 - m43 * afp(2.5) + m45; + v55 = m51 - m53 * afp(2.5) + m55; // store 36 #if NCNN_image_shader diff --git a/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_output.comp b/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_output.comp index dc1cdf337..10715d6af 100644 --- a/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_output.comp +++ b/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd43_transform_output.comp @@ -153,66 +153,112 @@ void main() afpvec4 v55 = buffer_ld4(top_tm_blob_data, v_tm_offset + 35 * psc(cstep)); #endif +#define sq2 1.41421356237 +#define sq2_d4 1.41421356237/4 + // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 
2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, sq2/2, -sq2/2, sq2, -sq2, 0.0f}, + // {0.0f, 0.5f, 0.5f, 2.0f, 2.0f, 0.0f}, + // {0.0f, sq2/4, -sq2/4, sq2*2, -sq2*2, 1.0f} // }; - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - // implicit transpose - afpvec4 m00 = v00 + (v01 + v02) + (v03 + v04); - afpvec4 m01 = v10 + (v11 + v12) + (v13 + v14); - afpvec4 m02 = v20 + (v21 + v22) + (v23 + v24); - afpvec4 m03 = v30 + (v31 + v32) + (v33 + v34); - afpvec4 m04 = v40 + (v41 + v42) + (v43 + v44); - afpvec4 m05 = v50 + (v51 + v52) + (v53 + v54); - - afpvec4 m10 = (v01 - v02) + (v03 - v04) * afp(2); - afpvec4 m11 = (v11 - v12) + (v13 - v14) * afp(2); - afpvec4 m12 = (v21 - v22) + (v23 - v24) * afp(2); - afpvec4 m13 = (v31 - v32) + (v33 - v34) * afp(2); - afpvec4 m14 = (v41 - v42) + (v43 - v44) * afp(2); - afpvec4 m15 = (v51 - v52) + (v53 - v54) * afp(2); - - afpvec4 m20 = (v01 + v02) + (v03 + v04) * afp(4); - afpvec4 m21 = (v11 + v12) + (v13 + v14) * afp(4); - afpvec4 m22 = (v21 + v22) + (v23 + v24) * afp(4); - afpvec4 m23 = (v31 + v32) + (v33 + v34) * afp(4); - afpvec4 m24 = (v41 + v42) + (v43 + v44) * afp(4); - afpvec4 m25 = (v51 + v52) + (v53 + v54) * afp(4); - - afpvec4 m30 = v05 + (v01 - v02) + (v03 - v04) * afp(8); - afpvec4 m31 = v15 + (v11 - v12) + (v13 - v14) * afp(8); - afpvec4 m32 = v25 + (v21 - v22) + (v23 - v24) * afp(8); - afpvec4 m33 = v35 + (v31 - v32) + (v33 - v34) * afp(8); - afpvec4 m34 = v45 + (v41 - v42) + (v43 - v44) * afp(8); - afpvec4 m35 = v55 + (v51 - v52) + (v53 - v54) * afp(8); - - v00 = m00 + (m01 + m02) + (m03 + m04); - v10 = m10 + (m11 + m12) + (m13 + m14); - v20 = m20 + (m21 + m22) + (m23 + m24); - v30 = m30 + (m31 + m32) + (m33 + m34); - - v01 = (m01 - m02) + (m03 - m04) * afp(2); - v11 = (m11 - m12) + (m13 - m14) * 
afp(2); - v21 = (m21 - m22) + (m23 - m24) * afp(2); - v31 = (m31 - m32) + (m33 - m34) * afp(2); - - v02 = (m01 + m02) + (m03 + m04) * afp(4); - v12 = (m11 + m12) + (m13 + m14) * afp(4); - v22 = (m21 + m22) + (m23 + m24) * afp(4); - v32 = (m31 + m32) + (m33 + m34) * afp(4); - - v03 = m05 + (m01 - m02) + (m03 - m04) * afp(8); - v13 = m15 + (m11 - m12) + (m13 - m14) * afp(8); - v23 = m25 + (m21 - m22) + (m23 - m24) * afp(8); - v33 = m35 + (m31 - m32) + (m33 - m34) * afp(8); + afpvec4 s0 = (v01 + v02) * afp(0.5); + afpvec4 s1 = (v11 + v12) * afp(0.5); + afpvec4 s2 = (v21 + v22) * afp(0.5); + afpvec4 s3 = (v31 + v32) * afp(0.5); + afpvec4 s4 = (v41 + v42) * afp(0.5); + afpvec4 s5 = (v51 + v52) * afp(0.5); + + afpvec4 t0 = (v01 - v02) * afp(sq2_d4); + afpvec4 t1 = (v11 - v12) * afp(sq2_d4); + afpvec4 t2 = (v21 - v22) * afp(sq2_d4); + afpvec4 t3 = (v31 - v32) * afp(sq2_d4); + afpvec4 t4 = (v41 - v42) * afp(sq2_d4); + afpvec4 t5 = (v51 - v52) * afp(sq2_d4); + + afpvec4 u0 = v03 + v04; + afpvec4 u1 = v13 + v14; + afpvec4 u2 = v23 + v24; + afpvec4 u3 = v33 + v34; + afpvec4 u4 = v43 + v44; + afpvec4 u5 = v53 + v54; + + afpvec4 v0 = (v03 - v04) * afp(sq2); + afpvec4 v1 = (v13 - v14) * afp(sq2); + afpvec4 v2 = (v23 - v24) * afp(sq2); + afpvec4 v3 = (v33 - v34) * afp(sq2); + afpvec4 v4 = (v43 - v44) * afp(sq2); + afpvec4 v5 = (v53 - v54) * afp(sq2); + + afpvec4 m00 = v00 + s0 + s0 + u0; + afpvec4 m01 = v10 + s1 + s1 + u1; + afpvec4 m02 = v20 + s2 + s2 + u2; + afpvec4 m03 = v30 + s3 + s3 + u3; + afpvec4 m04 = v40 + s4 + s4 + u4; + afpvec4 m05 = v50 + s5 + s5 + u5; + + afpvec4 m10 = t0 + t0 + v0; + afpvec4 m11 = t1 + t1 + v1; + afpvec4 m12 = t2 + t2 + v2; + afpvec4 m13 = t3 + t3 + v3; + afpvec4 m14 = t4 + t4 + v4; + afpvec4 m15 = t5 + t5 + v5; + + afpvec4 m20 = s0 + u0 + u0; + afpvec4 m21 = s1 + u1 + u1; + afpvec4 m22 = s2 + u2 + u2; + afpvec4 m23 = s3 + u3 + u3; + afpvec4 m24 = s4 + u4 + u4; + afpvec4 m25 = s5 + u5 + u5; + + afpvec4 m30 = v05 + t0 + v0 + v0; + afpvec4 m31 = v15 + 
t1 + v1 + v1; + afpvec4 m32 = v25 + t2 + v2 + v2; + afpvec4 m33 = v35 + t3 + v3 + v3; + afpvec4 m34 = v45 + t4 + v4 + v4; + afpvec4 m35 = v55 + t5 + v5 + v5; + + s0 = (m01 + m02) * afp(0.5); + s1 = (m11 + m12) * afp(0.5); + s2 = (m21 + m22) * afp(0.5); + s3 = (m31 + m32) * afp(0.5); + + t0 = (m01 - m02) * afp(sq2_d4); + t1 = (m11 - m12) * afp(sq2_d4); + t2 = (m21 - m22) * afp(sq2_d4); + t3 = (m31 - m32) * afp(sq2_d4); + + u0 = m03 + m04; + u1 = m13 + m14; + u2 = m23 + m24; + u3 = m33 + m34; + + v0 = (m03 - m04) * afp(sq2); + v1 = (m13 - m14) * afp(sq2); + v2 = (m23 - m24) * afp(sq2); + v3 = (m33 - m34) * afp(sq2); + + v00 = m00 + s0 + s0 + u0; + v10 = m10 + s1 + s1 + u1; + v20 = m20 + s2 + s2 + u2; + v30 = m30 + s3 + s3 + u3; + + v01 = t0 + t0 + v0; + v11 = t1 + t1 + v1; + v21 = t2 + t2 + v2; + v31 = t3 + t3 + v3; + + v02 = s0 + u0 + u0; + v12 = s1 + u1 + u1; + v22 = s2 + u2 + u2; + v32 = s3 + u3 + u3; + + v03 = m05 + t0 + v0 + v0; + v13 = m15 + t1 + v1 + v1; + v23 = m25 + t2 + v2 + v2; + v33 = m35 + t3 + v3 + v3; if (bias_term == 1) { diff --git a/src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_input.comp b/src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_input.comp index 25c538b84..3042679d3 100644 --- a/src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_input.comp +++ b/src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_input.comp @@ -157,106 +157,158 @@ void main() afpvec8 v55 = sy + 5 < psc(h) && sx + 5 < psc(w) ? 
buffer_ld8(bottom_blob_data, v_offset45.y + 5) : afpvec8(afpvec4(0.f), afpvec4(0.f)); #endif +#define sq2 1.41421356237 +#define sq2_d2 1.41421356237/2 + // const float itm[6][6] = { - // {1.0f, 0.0f, -1.25f, 0.0f, 0.25f, 0.0f}, - // {0.0f,-1.0f, -1.0f, 0.25f, 0.25f, 0.0f}, - // {0.0f, 1.0f, -1.0f,-0.25f, 0.25f, 0.0f}, - // {0.0f,-0.5f, -0.25f, 0.5f, 0.25f, 0.0f}, - // {0.0f, 0.5f, -0.25f,-0.5f, 0.25f, 0.0f}, - // {0.0f, 1.0f, 0.0f,-1.25f, 0.0f, 0.25f} + // {1.0f, 0.0f, -2.5f, 0.0f, 1.0f, 0.0f}, + // {0.0f, -sq2, -2.0f, sq2/2, 1.0f, 0.0f}, + // {0.0f, sq2, -2.0f, -sq2/2, 1.0f, 0.0f}, + // {0.0f, -sq2/2, -0.5f, sq2, 1.0f, 0.0f}, + // {0.0f, sq2/2, -0.5f, -sq2, 1.0f, 0.0f}, + // {0.0f, 1.0f, 0.0f, -2.5f, 0.0f, 1.0f} // }; - // 0 = 1 * r00 - 1.25 * r02 + 0.25 * r04 - // 1 = -1 * (r01 + r02) + 0.25 * (r04 + r03) - // 2 = 1 * (r01 - r02) + 0.25 * (r04 - r03) - // 3 = -0.5 * (r01 - r03) + 0.25 * (r04 - r02) - // 4 = 0.5 * (r01 - r03) + 0.25 * (r04 - r02) - // 5 = 1 * r01 - 1.25 * r03 + 0.25 * r05 - // implicit transpose - afpvec8 m00 = v00 - v02 * afp(1.25) + v04 * afp(0.25); - afpvec8 m01 = v10 - v12 * afp(1.25) + v14 * afp(0.25); - afpvec8 m02 = v20 - v22 * afp(1.25) + v24 * afp(0.25); - afpvec8 m03 = v30 - v32 * afp(1.25) + v34 * afp(0.25); - afpvec8 m04 = v40 - v42 * afp(1.25) + v44 * afp(0.25); - afpvec8 m05 = v50 - v52 * afp(1.25) + v54 * afp(0.25); - - afpvec8 m10 = (v04 + v03) * afp(0.25) - (v01 + v02); - afpvec8 m11 = (v14 + v13) * afp(0.25) - (v11 + v12); - afpvec8 m12 = (v24 + v23) * afp(0.25) - (v21 + v22); - afpvec8 m13 = (v34 + v33) * afp(0.25) - (v31 + v32); - afpvec8 m14 = (v44 + v43) * afp(0.25) - (v41 + v42); - afpvec8 m15 = (v54 + v53) * afp(0.25) - (v51 + v52); - - afpvec8 m20 = (v04 - v03) * afp(0.25) + (v01 - v02); - afpvec8 m21 = (v14 - v13) * afp(0.25) + (v11 - v12); - afpvec8 m22 = (v24 - v23) * afp(0.25) + (v21 - v22); - afpvec8 m23 = (v34 - v33) * afp(0.25) + (v31 - v32); - afpvec8 m24 = (v44 - v43) * afp(0.25) + (v41 - v42); - afpvec8 m25 = 
(v54 - v53) * afp(0.25) + (v51 - v52); - - afpvec8 m30 = (v04 - v02) * afp(0.25) - (v01 - v03) * afp(0.5); - afpvec8 m31 = (v14 - v12) * afp(0.25) - (v11 - v13) * afp(0.5); - afpvec8 m32 = (v24 - v22) * afp(0.25) - (v21 - v23) * afp(0.5); - afpvec8 m33 = (v34 - v32) * afp(0.25) - (v31 - v33) * afp(0.5); - afpvec8 m34 = (v44 - v42) * afp(0.25) - (v41 - v43) * afp(0.5); - afpvec8 m35 = (v54 - v52) * afp(0.25) - (v51 - v53) * afp(0.5); - - afpvec8 m40 = (v04 - v02) * afp(0.25) + (v01 - v03) * afp(0.5); - afpvec8 m41 = (v14 - v12) * afp(0.25) + (v11 - v13) * afp(0.5); - afpvec8 m42 = (v24 - v22) * afp(0.25) + (v21 - v23) * afp(0.5); - afpvec8 m43 = (v34 - v32) * afp(0.25) + (v31 - v33) * afp(0.5); - afpvec8 m44 = (v44 - v42) * afp(0.25) + (v41 - v43) * afp(0.5); - afpvec8 m45 = (v54 - v52) * afp(0.25) + (v51 - v53) * afp(0.5); - - afpvec8 m50 = v01 - v03 * afp(1.25) + v05 * afp(0.25); - afpvec8 m51 = v11 - v13 * afp(1.25) + v15 * afp(0.25); - afpvec8 m52 = v21 - v23 * afp(1.25) + v25 * afp(0.25); - afpvec8 m53 = v31 - v33 * afp(1.25) + v35 * afp(0.25); - afpvec8 m54 = v41 - v43 * afp(1.25) + v45 * afp(0.25); - afpvec8 m55 = v51 - v53 * afp(1.25) + v55 * afp(0.25); - - v00 = m00 - m02 * afp(1.25) + m04 * afp(0.25); - v10 = m10 - m12 * afp(1.25) + m14 * afp(0.25); - v20 = m20 - m22 * afp(1.25) + m24 * afp(0.25); - v30 = m30 - m32 * afp(1.25) + m34 * afp(0.25); - v40 = m40 - m42 * afp(1.25) + m44 * afp(0.25); - v50 = m50 - m52 * afp(1.25) + m54 * afp(0.25); - - v01 = (m04 + m03) * afp(0.25) - (m01 + m02); - v11 = (m14 + m13) * afp(0.25) - (m11 + m12); - v21 = (m24 + m23) * afp(0.25) - (m21 + m22); - v31 = (m34 + m33) * afp(0.25) - (m31 + m32); - v41 = (m44 + m43) * afp(0.25) - (m41 + m42); - v51 = (m54 + m53) * afp(0.25) - (m51 + m52); - - v02 = (m04 - m03) * afp(0.25) + (m01 - m02); - v12 = (m14 - m13) * afp(0.25) + (m11 - m12); - v22 = (m24 - m23) * afp(0.25) + (m21 - m22); - v32 = (m34 - m33) * afp(0.25) + (m31 - m32); - v42 = (m44 - m43) * afp(0.25) + (m41 - m42); - 
v52 = (m54 - m53) * afp(0.25) + (m51 - m52); - - v03 = (m04 - m02) * afp(0.25) - (m01 - m03) * afp(0.5); - v13 = (m14 - m12) * afp(0.25) - (m11 - m13) * afp(0.5); - v23 = (m24 - m22) * afp(0.25) - (m21 - m23) * afp(0.5); - v33 = (m34 - m32) * afp(0.25) - (m31 - m33) * afp(0.5); - v43 = (m44 - m42) * afp(0.25) - (m41 - m43) * afp(0.5); - v53 = (m54 - m52) * afp(0.25) - (m51 - m53) * afp(0.5); - - v04 = (m04 - m02) * afp(0.25) + (m01 - m03) * afp(0.5); - v14 = (m14 - m12) * afp(0.25) + (m11 - m13) * afp(0.5); - v24 = (m24 - m22) * afp(0.25) + (m21 - m23) * afp(0.5); - v34 = (m34 - m32) * afp(0.25) + (m31 - m33) * afp(0.5); - v44 = (m44 - m42) * afp(0.25) + (m41 - m43) * afp(0.5); - v54 = (m54 - m52) * afp(0.25) + (m51 - m53) * afp(0.5); - - v05 = m01 - m03 * afp(1.25) + m05 * afp(0.25); - v15 = m11 - m13 * afp(1.25) + m15 * afp(0.25); - v25 = m21 - m23 * afp(1.25) + m25 * afp(0.25); - v35 = m31 - m33 * afp(1.25) + m35 * afp(0.25); - v45 = m41 - m43 * afp(1.25) + m45 * afp(0.25); - v55 = m51 - m53 * afp(1.25) + m55 * afp(0.25); + afpvec8 m00 = v00 - v02 * afp(2.5) + v04; + afpvec8 m01 = v10 - v12 * afp(2.5) + v14; + afpvec8 m02 = v20 - v22 * afp(2.5) + v24; + afpvec8 m03 = v30 - v32 * afp(2.5) + v34; + afpvec8 m04 = v40 - v42 * afp(2.5) + v44; + afpvec8 m05 = v50 - v52 * afp(2.5) + v54; + + afpvec8 s0 = v03 * afp(sq2_d2) - v01 * afp(sq2); + afpvec8 s1 = v13 * afp(sq2_d2) - v11 * afp(sq2); + afpvec8 s2 = v23 * afp(sq2_d2) - v21 * afp(sq2); + afpvec8 s3 = v33 * afp(sq2_d2) - v31 * afp(sq2); + afpvec8 s4 = v43 * afp(sq2_d2) - v41 * afp(sq2); + afpvec8 s5 = v53 * afp(sq2_d2) - v51 * afp(sq2); + + afpvec8 t0 = v04 - v02 * afp(2); + afpvec8 t1 = v14 - v12 * afp(2); + afpvec8 t2 = v24 - v22 * afp(2); + afpvec8 t3 = v34 - v32 * afp(2); + afpvec8 t4 = v44 - v42 * afp(2); + afpvec8 t5 = v54 - v52 * afp(2); + + afpvec8 m10 = t0 + s0; + afpvec8 m11 = t1 + s1; + afpvec8 m12 = t2 + s2; + afpvec8 m13 = t3 + s3; + afpvec8 m14 = t4 + s4; + afpvec8 m15 = t5 + s5; + + afpvec8 m20 = t0 - 
s0; + afpvec8 m21 = t1 - s1; + afpvec8 m22 = t2 - s2; + afpvec8 m23 = t3 - s3; + afpvec8 m24 = t4 - s4; + afpvec8 m25 = t5 - s5; + + s0 = v03 * afp(sq2) - v01 * afp(sq2_d2); + s1 = v13 * afp(sq2) - v11 * afp(sq2_d2); + s2 = v23 * afp(sq2) - v21 * afp(sq2_d2); + s3 = v33 * afp(sq2) - v31 * afp(sq2_d2); + s4 = v43 * afp(sq2) - v41 * afp(sq2_d2); + s5 = v53 * afp(sq2) - v51 * afp(sq2_d2); + + t0 = v04 - v02 * afp(0.5); + t1 = v14 - v12 * afp(0.5); + t2 = v24 - v22 * afp(0.5); + t3 = v34 - v32 * afp(0.5); + t4 = v44 - v42 * afp(0.5); + t5 = v54 - v52 * afp(0.5); + + afpvec8 m30 = t0 + s0; + afpvec8 m31 = t1 + s1; + afpvec8 m32 = t2 + s2; + afpvec8 m33 = t3 + s3; + afpvec8 m34 = t4 + s4; + afpvec8 m35 = t5 + s5; + + afpvec8 m40 = t0 - s0; + afpvec8 m41 = t1 - s1; + afpvec8 m42 = t2 - s2; + afpvec8 m43 = t3 - s3; + afpvec8 m44 = t4 - s4; + afpvec8 m45 = t5 - s5; + + afpvec8 m50 = v01 - v03 * afp(2.5) + v05; + afpvec8 m51 = v11 - v13 * afp(2.5) + v15; + afpvec8 m52 = v21 - v23 * afp(2.5) + v25; + afpvec8 m53 = v31 - v33 * afp(2.5) + v35; + afpvec8 m54 = v41 - v43 * afp(2.5) + v45; + afpvec8 m55 = v51 - v53 * afp(2.5) + v55; + + v00 = m00 - m02 * afp(2.5) + m04; + v10 = m10 - m12 * afp(2.5) + m14; + v20 = m20 - m22 * afp(2.5) + m24; + v30 = m30 - m32 * afp(2.5) + m34; + v40 = m40 - m42 * afp(2.5) + m44; + v50 = m50 - m52 * afp(2.5) + m54; + + s0 = m03 * afp(sq2_d2) - m01 * afp(sq2); + s1 = m13 * afp(sq2_d2) - m11 * afp(sq2); + s2 = m23 * afp(sq2_d2) - m21 * afp(sq2); + s3 = m33 * afp(sq2_d2) - m31 * afp(sq2); + s4 = m43 * afp(sq2_d2) - m41 * afp(sq2); + s5 = m53 * afp(sq2_d2) - m51 * afp(sq2); + + t0 = m04 - m02 * afp(2); + t1 = m14 - m12 * afp(2); + t2 = m24 - m22 * afp(2); + t3 = m34 - m32 * afp(2); + t4 = m44 - m42 * afp(2); + t5 = m54 - m52 * afp(2); + + v01 = t0 + s0; + v11 = t1 + s1; + v21 = t2 + s2; + v31 = t3 + s3; + v41 = t4 + s4; + v51 = t5 + s5; + + v02 = t0 - s0; + v12 = t1 - s1; + v22 = t2 - s2; + v32 = t3 - s3; + v42 = t4 - s4; + v52 = t5 - s5; + + s0 = m03 * 
afp(sq2) - m01 * afp(sq2_d2); + s1 = m13 * afp(sq2) - m11 * afp(sq2_d2); + s2 = m23 * afp(sq2) - m21 * afp(sq2_d2); + s3 = m33 * afp(sq2) - m31 * afp(sq2_d2); + s4 = m43 * afp(sq2) - m41 * afp(sq2_d2); + s5 = m53 * afp(sq2) - m51 * afp(sq2_d2); + + t0 = m04 - m02 * afp(0.5); + t1 = m14 - m12 * afp(0.5); + t2 = m24 - m22 * afp(0.5); + t3 = m34 - m32 * afp(0.5); + t4 = m44 - m42 * afp(0.5); + t5 = m54 - m52 * afp(0.5); + + v03 = t0 + s0; + v13 = t1 + s1; + v23 = t2 + s2; + v33 = t3 + s3; + v43 = t4 + s4; + v53 = t5 + s5; + + v04 = t0 - s0; + v14 = t1 - s1; + v24 = t2 - s2; + v34 = t3 - s3; + v44 = t4 - s4; + v54 = t5 - s5; + + v05 = m01 - m03 * afp(2.5) + m05; + v15 = m11 - m13 * afp(2.5) + m15; + v25 = m21 - m23 * afp(2.5) + m25; + v35 = m31 - m33 * afp(2.5) + m35; + v45 = m41 - m43 * afp(2.5) + m45; + v55 = m51 - m53 * afp(2.5) + m55; // store 36 #if NCNN_image_shader diff --git a/src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_output.comp b/src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_output.comp index d04b2e4cf..b957abf66 100644 --- a/src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_output.comp +++ b/src/layer/vulkan/shader/convolution_pack8_3x3s1d1_winograd43_transform_output.comp @@ -154,66 +154,112 @@ void main() afpvec8 v55 = buffer_ld8(top_tm_blob_data, v_tm_offset + 35 * psc(cstep)); #endif +#define sq2 1.41421356237 +#define sq2_d4 1.41421356237/4 + // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, sq2/2, -sq2/2, sq2, -sq2, 0.0f}, + // {0.0f, 0.5f, 0.5f, 2.0f, 2.0f, 0.0f}, + // {0.0f, sq2/4, -sq2/4, sq2*2, -sq2*2, 1.0f} // }; - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - 
r04) * 8 - // implicit transpose - afpvec8 m00 = v00 + (v01 + v02) + (v03 + v04); - afpvec8 m01 = v10 + (v11 + v12) + (v13 + v14); - afpvec8 m02 = v20 + (v21 + v22) + (v23 + v24); - afpvec8 m03 = v30 + (v31 + v32) + (v33 + v34); - afpvec8 m04 = v40 + (v41 + v42) + (v43 + v44); - afpvec8 m05 = v50 + (v51 + v52) + (v53 + v54); - - afpvec8 m10 = (v01 - v02) + (v03 - v04) * afp(2); - afpvec8 m11 = (v11 - v12) + (v13 - v14) * afp(2); - afpvec8 m12 = (v21 - v22) + (v23 - v24) * afp(2); - afpvec8 m13 = (v31 - v32) + (v33 - v34) * afp(2); - afpvec8 m14 = (v41 - v42) + (v43 - v44) * afp(2); - afpvec8 m15 = (v51 - v52) + (v53 - v54) * afp(2); - - afpvec8 m20 = (v01 + v02) + (v03 + v04) * afp(4); - afpvec8 m21 = (v11 + v12) + (v13 + v14) * afp(4); - afpvec8 m22 = (v21 + v22) + (v23 + v24) * afp(4); - afpvec8 m23 = (v31 + v32) + (v33 + v34) * afp(4); - afpvec8 m24 = (v41 + v42) + (v43 + v44) * afp(4); - afpvec8 m25 = (v51 + v52) + (v53 + v54) * afp(4); - - afpvec8 m30 = v05 + (v01 - v02) + (v03 - v04) * afp(8); - afpvec8 m31 = v15 + (v11 - v12) + (v13 - v14) * afp(8); - afpvec8 m32 = v25 + (v21 - v22) + (v23 - v24) * afp(8); - afpvec8 m33 = v35 + (v31 - v32) + (v33 - v34) * afp(8); - afpvec8 m34 = v45 + (v41 - v42) + (v43 - v44) * afp(8); - afpvec8 m35 = v55 + (v51 - v52) + (v53 - v54) * afp(8); - - v00 = m00 + (m01 + m02) + (m03 + m04); - v10 = m10 + (m11 + m12) + (m13 + m14); - v20 = m20 + (m21 + m22) + (m23 + m24); - v30 = m30 + (m31 + m32) + (m33 + m34); - - v01 = (m01 - m02) + (m03 - m04) * afp(2); - v11 = (m11 - m12) + (m13 - m14) * afp(2); - v21 = (m21 - m22) + (m23 - m24) * afp(2); - v31 = (m31 - m32) + (m33 - m34) * afp(2); - - v02 = (m01 + m02) + (m03 + m04) * afp(4); - v12 = (m11 + m12) + (m13 + m14) * afp(4); - v22 = (m21 + m22) + (m23 + m24) * afp(4); - v32 = (m31 + m32) + (m33 + m34) * afp(4); - - v03 = m05 + (m01 - m02) + (m03 - m04) * afp(8); - v13 = m15 + (m11 - m12) + (m13 - m14) * afp(8); - v23 = m25 + (m21 - m22) + (m23 - m24) * afp(8); - v33 = m35 + (m31 - 
m32) + (m33 - m34) * afp(8); + afpvec8 s0 = (v01 + v02) * afp(0.5); + afpvec8 s1 = (v11 + v12) * afp(0.5); + afpvec8 s2 = (v21 + v22) * afp(0.5); + afpvec8 s3 = (v31 + v32) * afp(0.5); + afpvec8 s4 = (v41 + v42) * afp(0.5); + afpvec8 s5 = (v51 + v52) * afp(0.5); + + afpvec8 t0 = (v01 - v02) * afp(sq2_d4); + afpvec8 t1 = (v11 - v12) * afp(sq2_d4); + afpvec8 t2 = (v21 - v22) * afp(sq2_d4); + afpvec8 t3 = (v31 - v32) * afp(sq2_d4); + afpvec8 t4 = (v41 - v42) * afp(sq2_d4); + afpvec8 t5 = (v51 - v52) * afp(sq2_d4); + + afpvec8 u0 = v03 + v04; + afpvec8 u1 = v13 + v14; + afpvec8 u2 = v23 + v24; + afpvec8 u3 = v33 + v34; + afpvec8 u4 = v43 + v44; + afpvec8 u5 = v53 + v54; + + afpvec8 v0 = (v03 - v04) * afp(sq2); + afpvec8 v1 = (v13 - v14) * afp(sq2); + afpvec8 v2 = (v23 - v24) * afp(sq2); + afpvec8 v3 = (v33 - v34) * afp(sq2); + afpvec8 v4 = (v43 - v44) * afp(sq2); + afpvec8 v5 = (v53 - v54) * afp(sq2); + + afpvec8 m00 = v00 + s0 + s0 + u0; + afpvec8 m01 = v10 + s1 + s1 + u1; + afpvec8 m02 = v20 + s2 + s2 + u2; + afpvec8 m03 = v30 + s3 + s3 + u3; + afpvec8 m04 = v40 + s4 + s4 + u4; + afpvec8 m05 = v50 + s5 + s5 + u5; + + afpvec8 m10 = t0 + t0 + v0; + afpvec8 m11 = t1 + t1 + v1; + afpvec8 m12 = t2 + t2 + v2; + afpvec8 m13 = t3 + t3 + v3; + afpvec8 m14 = t4 + t4 + v4; + afpvec8 m15 = t5 + t5 + v5; + + afpvec8 m20 = s0 + u0 + u0; + afpvec8 m21 = s1 + u1 + u1; + afpvec8 m22 = s2 + u2 + u2; + afpvec8 m23 = s3 + u3 + u3; + afpvec8 m24 = s4 + u4 + u4; + afpvec8 m25 = s5 + u5 + u5; + + afpvec8 m30 = v05 + t0 + v0 + v0; + afpvec8 m31 = v15 + t1 + v1 + v1; + afpvec8 m32 = v25 + t2 + v2 + v2; + afpvec8 m33 = v35 + t3 + v3 + v3; + afpvec8 m34 = v45 + t4 + v4 + v4; + afpvec8 m35 = v55 + t5 + v5 + v5; + + s0 = (m01 + m02) * afp(0.5); + s1 = (m11 + m12) * afp(0.5); + s2 = (m21 + m22) * afp(0.5); + s3 = (m31 + m32) * afp(0.5); + + t0 = (m01 - m02) * afp(sq2_d4); + t1 = (m11 - m12) * afp(sq2_d4); + t2 = (m21 - m22) * afp(sq2_d4); + t3 = (m31 - m32) * afp(sq2_d4); + + u0 = m03 + m04; + u1 
= m13 + m14; + u2 = m23 + m24; + u3 = m33 + m34; + + v0 = (m03 - m04) * afp(sq2); + v1 = (m13 - m14) * afp(sq2); + v2 = (m23 - m24) * afp(sq2); + v3 = (m33 - m34) * afp(sq2); + + v00 = m00 + s0 + s0 + u0; + v10 = m10 + s1 + s1 + u1; + v20 = m20 + s2 + s2 + u2; + v30 = m30 + s3 + s3 + u3; + + v01 = t0 + t0 + v0; + v11 = t1 + t1 + v1; + v21 = t2 + t2 + v2; + v31 = t3 + t3 + v3; + + v02 = s0 + u0 + u0; + v12 = s1 + u1 + u1; + v22 = s2 + u2 + u2; + v32 = s3 + u3 + u3; + + v03 = m05 + t0 + v0 + v0; + v13 = m15 + t1 + v1 + v1; + v23 = m25 + t2 + v2 + v2; + v33 = m35 + t3 + v3 + v3; if (bias_term == 1) { diff --git a/tests/test_convolution.cpp b/tests/test_convolution.cpp index 9140750e8..9fbc13d34 100644 --- a/tests/test_convolution.cpp +++ b/tests/test_convolution.cpp @@ -41,16 +41,6 @@ static int test_convolution(int w, int h, int c, int outch, int kernel, int dila weights[1] = RandomMat(outch); float epsilon = 0.001; - // larget epsilon for winograd optimization - if (kernel == 3 && dilation == 1 && stride == 1 && c >= 16 && outch >= 16) - { - Randomize(a, -1, 1); - if (c >= 64 || outch >= 64) - Randomize(weights[0], -0.3, 0.3); - else - Randomize(weights[0], -1, 1); - epsilon = 0.002; - } int ret = test_layer("Convolution", pd, weights, a, epsilon); if (ret != 0) diff --git a/tests/test_convolution_1.cpp b/tests/test_convolution_1.cpp index b7641a4ea..b40413364 100644 --- a/tests/test_convolution_1.cpp +++ b/tests/test_convolution_1.cpp @@ -41,16 +41,6 @@ static int test_convolution(int w, int h, int c, int outch, int kernel, int dila weights[1] = RandomMat(outch); float epsilon = 0.001; - // larget epsilon for winograd optimization - if (kernel == 3 && dilation == 1 && stride == 1 && c >= 16 && outch >= 16) - { - Randomize(a, -1, 1); - if (c >= 64 || outch >= 64) - Randomize(weights[0], -0.3, 0.3); - else - Randomize(weights[0], -1, 1); - epsilon = 0.002; - } int ret = test_layer("Convolution", pd, weights, a, epsilon); if (ret != 0) diff --git 
a/tests/test_convolution_2.cpp b/tests/test_convolution_2.cpp index f30ea4e74..ca3ca9815 100644 --- a/tests/test_convolution_2.cpp +++ b/tests/test_convolution_2.cpp @@ -41,16 +41,6 @@ static int test_convolution(int w, int h, int c, int outch, int kernel, int dila weights[1] = RandomMat(outch); float epsilon = 0.001; - // larget epsilon for winograd optimization - if (kernel == 3 && dilation == 1 && stride == 1 && c >= 16 && outch >= 16) - { - Randomize(a, -1, 1); - if (c >= 64 || outch >= 64) - Randomize(weights[0], -0.3, 0.3); - else - Randomize(weights[0], -1, 1); - epsilon = 0.002; - } int ret = test_layer("Convolution", pd, weights, a, epsilon); if (ret != 0)