From: @wangdongxu6 Reviewed-by: @ddwsky,@zhanghaibo5 Signed-off-by: @ddwsky pull/15013/MERGE
| @@ -91,11 +91,11 @@ __kernel void Conv2D_H1W1C1(__read_only image2d_t input, __write_only image2d_t | |||||
| out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0)); | out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0)); | ||||
| } | } | ||||
| #ifndef EXCEDD_MAX_IMAGE2D_WIDTH | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| #else | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| #endif | |||||
| if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| } else { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| } | |||||
| } | } | ||||
| __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, | __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, | ||||
| @@ -172,17 +172,17 @@ __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t | |||||
| out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0)); | out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0)); | ||||
| } | } | ||||
| #ifndef EXCEDD_MAX_IMAGE2D_WIDTH | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); | |||||
| } // end if (oh1 < OH) | |||||
| #else | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); | |||||
| } // end (oh1 < OH) | |||||
| #endif | |||||
| if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); | |||||
| } // end if (oh1 < OH) | |||||
| } else { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); | |||||
| } // end (oh1 < OH) | |||||
| } | |||||
| } | } | ||||
| __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, | __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, | ||||
| @@ -283,27 +283,29 @@ __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t | |||||
| out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1)); | out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1)); | ||||
| } | } | ||||
| #ifndef EXCEDD_MAX_IMAGE2D_WIDTH | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); | |||||
| } // end if (oh1 < OH) | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); | |||||
| if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| if (oh1 < OH) { | if (oh1 < OH) { | ||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); | |||||
| } // end if (oh1 < OH) | } // end if (oh1 < OH) | ||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| #else | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); | |||||
| } // end (oh1 < OH) | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); | |||||
| } // end if (oh1 < OH) | |||||
| #endif | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); | |||||
| } // end if (oh1 < OH) | |||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| } else { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); | |||||
| } // end (oh1 < OH) | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); | |||||
| } // end if (oh1 < OH) | |||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| } | |||||
| } | } | ||||
| __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, | __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, | ||||
| @@ -454,35 +456,37 @@ __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t | |||||
| out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); | out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); | ||||
| } | } | ||||
| #ifndef EXCEDD_MAX_IMAGE2D_WIDTH | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); | |||||
| } // end if (oh1 < OH) | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); | |||||
| if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); | |||||
| if (oh1 < OH) { | if (oh1 < OH) { | ||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); | |||||
| } // end if (oh1 < OH) | } // end if (oh1 < OH) | ||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| #else | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); | |||||
| } // end (oh1 < OH) | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); | |||||
| } // end if (oh1 < OH) | |||||
| #endif | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); | |||||
| } // end if (oh1 < OH) | |||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| } else { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); | |||||
| } // end (oh1 < OH) | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); | |||||
| } // end if (oh1 < OH) | |||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| } | |||||
| } | } | ||||
| __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2d_t output, | __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2d_t output, | ||||
| @@ -640,33 +644,35 @@ __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2 | |||||
| out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); | out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); | ||||
| } | } | ||||
| #ifndef EXCEDD_MAX_IMAGE2D_WIDTH | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); | |||||
| } // end if (oh1 < OH) | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); | |||||
| if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); | |||||
| if (oh1 < OH) { | if (oh1 < OH) { | ||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); | |||||
| } // end if (oh1 < OH) | } // end if (oh1 < OH) | ||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| #else | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); | |||||
| } // end (oh1 < OH) | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); | |||||
| } // end if (oh1 < OH) | |||||
| #endif | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); | |||||
| } // end if (oh1 < OH) | |||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| } else { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); | |||||
| WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); | |||||
| } // end (oh1 < OH) | |||||
| if (co_slice1 < CO_SLICES) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); | |||||
| if (oh1 < OH) { | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); | |||||
| WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); | |||||
| } // end if (oh1 < OH) | |||||
| } // end if (co_slice1 < CO_SLICES) | |||||
| } | |||||
| } | } | ||||
| @@ -144,9 +144,7 @@ void Conv2DOpenCLKernel::BuildKernel() { | |||||
| kernel_name << "_Img"; | kernel_name << "_Img"; | ||||
| } | } | ||||
| ocl_runtime_->LoadSource(program_name, GetActDefines() + conv2d_source); | ocl_runtime_->LoadSource(program_name, GetActDefines() + conv2d_source); | ||||
| std::string build_option = | |||||
| (OW_ * CO_SLICES_ <= ocl_runtime_->GetMaxImage2DWidth()) ? "" : " -DEXCEDD_MAX_IMAGE2D_WIDTH"; | |||||
| ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str(), {build_option}); | |||||
| ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str()); | |||||
| } | } | ||||
| void Conv2DOpenCLKernel::SetBlockSize() { | void Conv2DOpenCLKernel::SetBlockSize() { | ||||
| @@ -379,7 +379,7 @@ std::string FusionEltwiseOpenCLKernel::CodegenCore(FusionEltwiseParameter *param | |||||
| code << cl_prefix << "FLT4 " << exp0 << " = exp(" + var0 + ");\n"; | code << cl_prefix << "FLT4 " << exp0 << " = exp(" + var0 + ");\n"; | ||||
| code << cl_prefix << "FLT4 " << exp1 << " = exp(-" + var0 + ");\n"; | code << cl_prefix << "FLT4 " << exp1 << " = exp(-" + var0 + ");\n"; | ||||
| code << cl_prefix << "FLT4 " << out_name << " = (" << exp0 << " - " << exp1 << ") / (" << exp0 << " + " << exp1 | code << cl_prefix << "FLT4 " << out_name << " = (" << exp0 << " - " << exp1 << ") / (" << exp0 << " + " << exp1 | ||||
| << "));\n"; | |||||
| << ");\n"; | |||||
| } | } | ||||
| } | } | ||||
| @@ -143,11 +143,11 @@ int StridedSliceOpenCLKernel::InitConstArgs() { | |||||
| // avoid begin is out of range | // avoid begin is out of range | ||||
| begin_.s[i] = std::clamp(begin_.s[i], 0, input_shape_.s[i] - 1); | begin_.s[i] = std::clamp(begin_.s[i], 0, input_shape_.s[i] - 1); | ||||
| // end is negative | // end is negative | ||||
| if (end_.s[i] < 0) { | |||||
| if (end_.s[i] <= 0) { | |||||
| end_.s[i] += input_shape_.s[i]; | end_.s[i] += input_shape_.s[i]; | ||||
| } | } | ||||
| // avoid end is out of range | // avoid end is out of range | ||||
| end_.s[i] = std::clamp(end_.s[i], -1, input_shape_.s[i]); | |||||
| end_.s[i] = std::clamp(end_.s[i], 0, input_shape_.s[i]); | |||||
| // check stride begin end | // check stride begin end | ||||
| if (stride_.s[i] > 0) { | if (stride_.s[i] > 0) { | ||||
| @@ -42,8 +42,6 @@ int ToFormatOpenCLKernel::CheckSpecs() { | |||||
| MS_LOG(ERROR) << "Unsupported data type " << data_type; | MS_LOG(ERROR) << "Unsupported data type " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto parameter = reinterpret_cast<OpenCLToFormatParameter *>(op_parameter_); | |||||
| out_mem_type_ = parameter->out_mem_type; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -103,8 +101,10 @@ int ToFormatOpenCLKernel::Run() { | |||||
| } | } | ||||
| int ToFormatOpenCLKernel::InferShape() { | int ToFormatOpenCLKernel::InferShape() { | ||||
| out_tensors_[0]->set_shape(in_tensors_[0]->shape()); | |||||
| op_parameter_->infer_flag_ = false; | |||||
| if (!op_parameter_->infer_flag_) { | |||||
| op_parameter_->infer_flag_ = true; | |||||
| out_tensors_.front()->set_shape(in_tensors_.front()->shape()); | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -25,7 +25,11 @@ | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| class ToFormatOpenCLKernel : public OpenCLKernel { | class ToFormatOpenCLKernel : public OpenCLKernel { | ||||
| public: | public: | ||||
| using OpenCLKernel::OpenCLKernel; | |||||
| ToFormatOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| : OpenCLKernel(parameter, inputs, outputs, ctx) { | |||||
| out_mem_type_ = reinterpret_cast<OpenCLToFormatParameter *>(op_parameter_)->out_mem_type; | |||||
| } | |||||
| ~ToFormatOpenCLKernel() override = default; | ~ToFormatOpenCLKernel() override = default; | ||||
| int Run() override; | int Run() override; | ||||
| @@ -135,6 +135,8 @@ std::vector<T *> RemoveDuplicationsButKeepOrder(const std::vector<T *> &vec) { | |||||
| void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) { | void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) { | ||||
| MS_ASSERT(a); | MS_ASSERT(a); | ||||
| MS_ASSERT(b); | MS_ASSERT(b); | ||||
| MS_ASSERT(a->op_parameter()->infer_flag_); | |||||
| MS_ASSERT(b->op_parameter()->infer_flag_); | |||||
| if (remove_a) { // pred->tensor0->a->tensor1->b: remove a tensor1 | if (remove_a) { // pred->tensor0->a->tensor1->b: remove a tensor1 | ||||
| // update pred out_kernels: a.in_kernels.out_kernels.replace(a,b) | // update pred out_kernels: a.in_kernels.out_kernels.replace(a,b) | ||||
| for (auto *pred : a->in_kernels()) { | for (auto *pred : a->in_kernels()) { | ||||
| @@ -231,6 +233,9 @@ void TryMergePadXxx(LiteKernel *node, std::set<LiteKernel *> *removed_set, std:: | |||||
| } | } | ||||
| LiteKernel *pad = node->in_kernels().front(); | LiteKernel *pad = node->in_kernels().front(); | ||||
| MS_ASSERT(pad); | MS_ASSERT(pad); | ||||
| if (!pad->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| if (pad->in_tensors().front()->shape().size() != 4) { | if (pad->in_tensors().front()->shape().size() != 4) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -264,6 +269,10 @@ void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_se | |||||
| // group must be 1 | // group must be 1 | ||||
| LiteKernel *conv = reshape->in_kernels().front(); | LiteKernel *conv = reshape->in_kernels().front(); | ||||
| MS_ASSERT(conv); | MS_ASSERT(conv); | ||||
| if (!conv->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| auto *param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(conv)->GetParameter()); | auto *param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(conv)->GetParameter()); | ||||
| MS_ASSERT(param); | MS_ASSERT(param); | ||||
| if (param->group_ != 1) { | if (param->group_ != 1) { | ||||
| @@ -286,6 +295,9 @@ void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set, | |||||
| bool NC_N11C_flag = NC_N11C(reshape); | bool NC_N11C_flag = NC_N11C(reshape); | ||||
| if (NC_N11C_flag || N11C_NC(reshape)) { | if (NC_N11C_flag || N11C_NC(reshape)) { | ||||
| LiteKernel *fc = reshape->in_kernels().front(); | LiteKernel *fc = reshape->in_kernels().front(); | ||||
| if (!fc->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| MS_ASSERT(fc); | MS_ASSERT(fc); | ||||
| MergeRemoveB(fc, reshape, removed_set); | MergeRemoveB(fc, reshape, removed_set); | ||||
| MS_LOG(DEBUG) << "Merge FullConnection and Reshape" + (NC_N11C_flag ? std::string("(NC->N11C)") : "(N11C->NC)") + | MS_LOG(DEBUG) << "Merge FullConnection and Reshape" + (NC_N11C_flag ? std::string("(NC->N11C)") : "(N11C->NC)") + | ||||
| @@ -303,6 +315,9 @@ void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set, std: | |||||
| } | } | ||||
| LiteKernel *reshape = fc->in_kernels().front(); | LiteKernel *reshape = fc->in_kernels().front(); | ||||
| MS_ASSERT(reshape); | MS_ASSERT(reshape); | ||||
| if (!reshape->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| bool NC11_NC_flag = NC11_NC(reshape); | bool NC11_NC_flag = NC11_NC(reshape); | ||||
| if (NC11_NC_flag || NC_N11C(reshape)) { | if (NC11_NC_flag || NC_N11C(reshape)) { | ||||
| MergeRemoveA(reshape, fc, removed_set); | MergeRemoveA(reshape, fc, removed_set); | ||||
| @@ -317,6 +332,9 @@ void TryMergeArithmeticAct(LiteKernel *act, std::set<LiteKernel *> *removed_set) | |||||
| MS_ASSERT(removed_set); | MS_ASSERT(removed_set); | ||||
| LiteKernel *arithmetic = act->in_kernels().front(); | LiteKernel *arithmetic = act->in_kernels().front(); | ||||
| MS_ASSERT(arithmetic); | MS_ASSERT(arithmetic); | ||||
| if (!arithmetic->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| auto *arithmetic_param = | auto *arithmetic_param = | ||||
| reinterpret_cast<ArithmeticParameter *>(reinterpret_cast<OpenCLKernel *>(arithmetic)->GetParameter()); | reinterpret_cast<ArithmeticParameter *>(reinterpret_cast<OpenCLKernel *>(arithmetic)->GetParameter()); | ||||
| auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter()); | auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter()); | ||||
| @@ -339,6 +357,10 @@ void TryMergeXxxActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) | |||||
| MS_ASSERT(removed_set); | MS_ASSERT(removed_set); | ||||
| auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter()); | auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter()); | ||||
| LiteKernel *node = act->in_kernels().front(); | LiteKernel *node = act->in_kernels().front(); | ||||
| MS_ASSERT(node); | |||||
| if (!node->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| auto *param = reinterpret_cast<ParamType *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter()); | auto *param = reinterpret_cast<ParamType *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter()); | ||||
| MS_ASSERT(param); | MS_ASSERT(param); | ||||
| @@ -379,6 +401,9 @@ void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set, s | |||||
| } | } | ||||
| LiteKernel *conv = prelu->in_kernels().front(); | LiteKernel *conv = prelu->in_kernels().front(); | ||||
| MS_ASSERT(conv); | MS_ASSERT(conv); | ||||
| if (!conv->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| if (reinterpret_cast<Conv2DOpenCLKernel *>(conv)->use_winograd_) { | if (reinterpret_cast<Conv2DOpenCLKernel *>(conv)->use_winograd_) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -475,6 +500,9 @@ void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set, | |||||
| } | } | ||||
| LiteKernel *deconv = scale->in_kernels().front(); | LiteKernel *deconv = scale->in_kernels().front(); | ||||
| MS_ASSERT(deconv); | MS_ASSERT(deconv); | ||||
| if (!deconv->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| // check act_type_ | // check act_type_ | ||||
| auto *deconv_param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(deconv)->GetParameter()); | auto *deconv_param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(deconv)->GetParameter()); | ||||
| @@ -546,6 +574,9 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol | |||||
| // Eltwise + Eltwise | // Eltwise + Eltwise | ||||
| int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) { | int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) { | ||||
| if (!node->op_parameter()->infer_flag_) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_ASSERT(node); | MS_ASSERT(node); | ||||
| MS_ASSERT(nodes); | MS_ASSERT(nodes); | ||||
| MS_ASSERT(removed_set); | MS_ASSERT(removed_set); | ||||
| @@ -560,6 +591,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set | |||||
| std::map<lite::Tensor *, FusionEltwiseParameter *> pred_params; | std::map<lite::Tensor *, FusionEltwiseParameter *> pred_params; | ||||
| for (LiteKernel *pred : preds) { | for (LiteKernel *pred : preds) { | ||||
| MS_ASSERT(pred); | MS_ASSERT(pred); | ||||
| if (!pred->op_parameter()->infer_flag_) { | |||||
| continue; | |||||
| } | |||||
| if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) { | if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) { | ||||
| auto *tensor = pred->out_tensors().front(); | auto *tensor = pred->out_tensors().front(); | ||||
| MS_ASSERT(pred->out_kernels().front() == node); | MS_ASSERT(pred->out_kernels().front() == node); | ||||
| @@ -589,6 +623,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set | |||||
| } | } | ||||
| void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) { | void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) { | ||||
| if (!node->op_parameter()->infer_flag_) { | |||||
| return; | |||||
| } | |||||
| switch (node->Type()) { | switch (node->Type()) { | ||||
| case schema::PrimitiveType_Conv2DFusion: | case schema::PrimitiveType_Conv2DFusion: | ||||
| case schema::PrimitiveType_Conv2dTransposeFusion: { | case schema::PrimitiveType_Conv2dTransposeFusion: { | ||||
| @@ -100,6 +100,11 @@ void OpenCLSubGraph::ReplaceOutTensorAndKernelToConvert(const lite::Tensor *in_t | |||||
| iv->set_out_tensors(tensors); | iv->set_out_tensors(tensors); | ||||
| in_convert_op->AddInKernel(iv); | in_convert_op->AddInKernel(iv); | ||||
| } | } | ||||
| if (in_convert_op->in_kernels().empty()) { | |||||
| in_convert_op->op_parameter()->infer_flag_ = true; | |||||
| } else { | |||||
| in_convert_op->op_parameter()->infer_flag_ = in_opencl_op->in_kernels().front()->op_parameter()->infer_flag_; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -143,7 +148,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors, | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| parameter->op_parameter.type_ = PRIM_TO_FORMAT; | parameter->op_parameter.type_ = PRIM_TO_FORMAT; | ||||
| parameter->op_parameter.infer_flag_ = false; | |||||
| parameter->op_parameter.infer_flag_ = false; // infer_flag_ is set in ReplaceOutTensorAndKernelToConvert() | |||||
| parameter->out_mem_type = mem_type; | parameter->out_mem_type = mem_type; | ||||
| out_parameters->emplace_back(parameter); | out_parameters->emplace_back(parameter); | ||||
| LiteKernel *in_convert_op = nullptr; | LiteKernel *in_convert_op = nullptr; | ||||
| @@ -155,7 +160,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors, | |||||
| {new_tensor}, {in_tensor}, reinterpret_cast<OpParameter *>(parameter), context_, desc); | {new_tensor}, {in_tensor}, reinterpret_cast<OpParameter *>(parameter), context_, desc); | ||||
| } | } | ||||
| MS_ASSERT(in_convert_op); | MS_ASSERT(in_convert_op); | ||||
| if (in_convert_op == nullptr) { | |||||
| if (in_convert_op == nullptr || reinterpret_cast<ToFormatOpenCLKernel *>(in_convert_op)->CheckSpecs() != RET_OK) { | |||||
| MS_LOG(ERROR) << "OpenCLSubGraph create op failed!"; | MS_LOG(ERROR) << "OpenCLSubGraph create op failed!"; | ||||
| delete new_tensor; | delete new_tensor; | ||||
| new_tensor = nullptr; | new_tensor = nullptr; | ||||