diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl index 674bc64c91..f31b4045c5 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl @@ -91,11 +91,11 @@ __kernel void Conv2D_H1W1C1(__read_only image2d_t input, __write_only image2d_t out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0)); } - if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); - } else { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - } +#ifndef EXCEDD_MAX_IMAGE2D_WIDTH + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); +#else + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); +#endif } __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, @@ -172,17 +172,17 @@ __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0)); } - if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); - } // end if (oh1 < OH) - } else { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); - } // end (oh1 < OH) - } +#ifndef EXCEDD_MAX_IMAGE2D_WIDTH + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); + } // end if (oh1 < OH) +#else + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); + } // end (oh1 < OH) +#endif } __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, @@ -283,29 +283,27 @@ __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1)); } - if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); +#ifndef EXCEDD_MAX_IMAGE2D_WIDTH + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); + } // end if (oh1 < OH) + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); } // end if (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); - } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) - } else { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); - } // end (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); - } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) - } + } // end if (co_slice1 < CO_SLICES) +#else + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); + } // end (oh1 < OH) + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); + } // end if (oh1 < OH) +#endif } __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, @@ -456,37 +454,35 @@ __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); } - if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); +#ifndef EXCEDD_MAX_IMAGE2D_WIDTH + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); + } // end if (oh1 < OH) + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); } // end if (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); - } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) - } else { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); - } // end (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); - } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) - } + } // end if (co_slice1 < CO_SLICES) +#else + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); + } // end (oh1 < OH) + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); + } // end if (oh1 < OH) +#endif } __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2d_t output, @@ -644,35 +640,33 @@ __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2 out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); } - if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); +#ifndef EXCEDD_MAX_IMAGE2D_WIDTH + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); + } // end if (oh1 < OH) + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); } // end if (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); - } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) - } else { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); - } // end (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); - } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) - } + } // end if (co_slice1 < CO_SLICES) +#else + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); + } // end (oh1 < OH) + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); + } // end if (oh1 < OH) +#endif } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl index 6eb7f96a9f..ca3b53af42 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl @@ -35,29 +35,79 @@ __kernel void Winograd4x4To36(__read_only image2d_t input, // height=N*H FLT4 BtD_row[6] = {0}; int h = tile_h * 4 - pad; int w = tile_w * 4 - pad; - for (int y = 0; y < 6; y++) { - int x_idx = w * CI_SLICES + ci_slice; - for (int x = 0; x < 6; x++) { - // no need to check w: because ci_slice is in [0, CI_SLICES). when w<0, x_idx<0; w>=W, x_idx>=W*CI_SLICES - // if (w < 0 || w >= W) { continue; } - BtD_row[x] += Bt_row[y] * READ_IMAGE(input, smp_zero, (int2)(x_idx, h)); - x_idx += CI_SLICES; - } - h++; - } - int y_idx = ci_slice * 36 + row * 6; - for (int y = 0; y < 6; y++) { - FLT4 acc = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); - for (int x = 0; x < 6; x++) { - acc += BtD_row[x] * Bt[y * 6 + x]; - } + int x_idx = w * CI_SLICES + ci_slice; + FLT bt0 = Bt_row[0], bt1 = Bt_row[1], bt2 = Bt_row[2], bt3 = Bt_row[3], bt4 = Bt_row[4], bt5 = Bt_row[5]; + BtD_row[0] = + bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) + + bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) + + bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5)); + x_idx += CI_SLICES; + BtD_row[1] = + bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) + + bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) + + bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5)); + x_idx += CI_SLICES; + BtD_row[2] = + bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) + + bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) + + bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5)); + x_idx += CI_SLICES; + BtD_row[3] = + bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) + + bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) + + bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5)); + x_idx += CI_SLICES; + BtD_row[4] = + bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) + + bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) + + bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5)); + x_idx += CI_SLICES; + BtD_row[5] = + bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) + + bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) + + bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5)); + #if FP16_ENABLE - acc = min(acc, HALF_MAX); - acc = max(acc, -HALF_MAX); +#ifndef HALF_MAX // adreno not exist +#define HALF_MAX 0x1.ffcp15h #endif - WRITE_IMAGE(output, (int2)(tile_hw, y_idx + y), acc); - } +#define LimitAcc() \ + acc = min(acc, HALF_MAX); \ + acc = max(acc, -HALF_MAX); +#else +#define LimitAcc() \ + {} +#endif + + int y_idx = ci_slice * 36 + row * 6; + FLT4 acc = BtD_row[0] + (FLT)(-2.5f) * BtD_row[2] + BtD_row[4]; + LimitAcc(); + WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc); + + FLT4 tmp0 = (FLT)(0.9428091049f) * BtD_row[1] + (FLT)(-0.4714044929f) * BtD_row[3]; + FLT4 tmp1 = (FLT)(1.3333333731f) * BtD_row[2] + (FLT)(-0.6666667461f) * BtD_row[4]; + acc = tmp0 + tmp1; + LimitAcc(); + WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc); + + acc = -tmp0 + tmp1; + LimitAcc(); + WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc); + + tmp0 = (FLT)(-0.1178511307f) * BtD_row[1] + (FLT)(0.2357022613f) * BtD_row[3]; + tmp1 = (FLT)(-0.0833333358f) * BtD_row[2] + (FLT)(0.1666666865f) * BtD_row[4]; + acc = tmp0 + tmp1; + LimitAcc(); + WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc); + + acc = -tmp0 + tmp1; + LimitAcc(); + WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc); + + acc = BtD_row[1] + (FLT)(-2.5f) * BtD_row[3] + BtD_row[5]; + LimitAcc(); + WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc); } __kernel void WinogradConv2D(__read_only image2d_t input, // height=CI_SLICES*36 width=TILE_HW @@ -181,6 +231,22 @@ constant FLT At[24] = {1.0000000000f, 1.0000000000f, 1.0000000000f, 1.000000000 0.0000000000f, 0.4999999702f, 0.4999999702f, 1.9999998808f, 1.9999998808f, 0.0000000000f, 0.0000000000f, 0.3535533845f, -0.3535533845f, 2.8284270763f, -2.8284270763f, 1.0000000000f}; +#define UpdateAcc() \ + if (bias != 0) acc += bias[co_slice]; \ + if (act_type == ActivationType_RELU) { \ + acc = max(acc, (FLT4)(0.0f)); \ + } else if (act_type == ActivationType_RELU6) { \ + acc = clamp(acc, (FLT4)(0.0f), (FLT4)(6.0f)); \ + } else if (act_type == ActivationType_TANH) { \ + FLT4 exp0 = exp(acc); \ + FLT4 exp1 = exp(-acc); \ + acc = (exp0 - exp1) / (exp0 + exp1); \ + } else if (act_type == ActivationType_LEAKY_RELU) { \ + DO_LEAKY_RELU(acc, alpha); \ + } else if (act_type == ActivationType_SIGMOID) { \ + acc = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-acc)); \ + } + __kernel void Winograd36To4x4(__read_only image2d_t input, // height=CO_SLICES*36 width=TILE_HW __write_only image2d_t output, // height=N*H width=W*CO_SLICES __global FLT4 *bias, @@ -198,11 +264,49 @@ __kernel void Winograd36To4x4(__read_only image2d_t input, // height=CO_SLICE constant FLT *At_row = At + row * 6; FLT4 AtM_row[6] = {0}; - for (int y = 0, idx = co_slice * 36; y < 6; y++) { - for (int x = 0; x < 6; x++, idx++) { - AtM_row[x] += At_row[y] * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx)); - } - } + int idx = co_slice * 36; + FLT at = At_row[0]; + AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + at = At_row[1]; + AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + at = At_row[2]; + AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + at = At_row[3]; + AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + at = At_row[4]; + AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + at = At_row[5]; + AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); + AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++)); int TILE_W = UP_DIV(W, 4); int tile_w = tile_hw % TILE_W; @@ -210,30 +314,24 @@ __kernel void Winograd36To4x4(__read_only image2d_t input, // height=CO_SLICE int h = tile_h * 4 + row; int w = tile_w * 4; int x_idx = w * CO_SLICES + co_slice; - for (int x = 0, idx = 0; x < 4; x++) { - FLT4 acc = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); - for (int y = 0; y < 6; y++, idx++) { - acc += AtM_row[y] * At[idx]; - } - - if (bias != 0) { - acc += bias[co_slice]; - } - - if (act_type == ActivationType_RELU) { - acc = max(acc, (FLT4)(0.0f)); - } else if (act_type == ActivationType_RELU6) { - acc = clamp(acc, (FLT4)(0.0f), (FLT4)(6.0f)); - } else if (act_type == ActivationType_TANH) { - FLT4 exp0 = exp(acc); - FLT4 exp1 = exp(-acc); - acc = (exp0 - exp1) / (exp0 + exp1); - } else if (act_type == ActivationType_LEAKY_RELU) { - DO_LEAKY_RELU(acc, alpha); - } else if (act_type == ActivationType_SIGMOID) { - acc = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-acc)); - } - WRITE_IMAGE(output, (int2)(x_idx, h), acc); - x_idx += CO_SLICES; - } + + FLT4 acc = AtM_row[0] + AtM_row[1] + AtM_row[2] + AtM_row[3] + AtM_row[4]; + UpdateAcc(); + WRITE_IMAGE(output, (int2)(x_idx, h), acc); + x_idx += CO_SLICES; + + acc = (FLT)(0.7071067691f) * (AtM_row[1] - AtM_row[2]) + (FLT)(1.4142135382f) * (AtM_row[3] - AtM_row[4]); + UpdateAcc(); + WRITE_IMAGE(output, (int2)(x_idx, h), acc); + x_idx += CO_SLICES; + + acc = (FLT)(0.5f) * (AtM_row[1] + AtM_row[2]) + (FLT)(2.0f) * (AtM_row[3] + AtM_row[4]); + UpdateAcc(); + WRITE_IMAGE(output, (int2)(x_idx, h), acc); + x_idx += CO_SLICES; + + acc = + (FLT)(0.3535533845f) * (AtM_row[1] - AtM_row[2]) + (FLT)(2.8284270763f) * (AtM_row[3] - AtM_row[4]) + AtM_row[5]; + UpdateAcc(); + WRITE_IMAGE(output, (int2)(x_idx, h), acc); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc index 8eef95289a..aa745b6613 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc @@ -144,7 +144,9 @@ void Conv2DOpenCLKernel::BuildKernel() { kernel_name << "_Img"; } ocl_runtime_->LoadSource(program_name, GetActDefines() + conv2d_source); - ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str()); + std::string build_option = + (OW_ * CO_SLICES_ <= ocl_runtime_->GetMaxImage2DWidth()) ? "" : " -DEXCEDD_MAX_IMAGE2D_WIDTH"; + ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str(), {build_option}); } void Conv2DOpenCLKernel::SetBlockSize() { @@ -436,13 +438,6 @@ OpParameter *CreateFcParam(const ConvParameter *conv_param, const std::vector