From aebebe77a49428fe6bc043c4257ae35d183c1489 Mon Sep 17 00:00:00 2001 From: wangdongxu Date: Mon, 12 Apr 2021 18:06:00 +0800 Subject: [PATCH] fix some bugs --- .../src/runtime/kernel/opencl/cl/conv2d.cl | 184 +++++++++--------- .../runtime/kernel/opencl/kernel/conv2d.cc | 4 +- .../kernel/opencl/kernel/fusion_eltwise.cc | 2 +- .../kernel/opencl/kernel/strided_slice.cc | 4 +- .../runtime/kernel/opencl/kernel/to_format.cc | 8 +- .../runtime/kernel/opencl/kernel/to_format.h | 6 +- .../runtime/kernel/opencl/opencl_fusion.cc | 37 ++++ .../runtime/kernel/opencl/opencl_subgraph.cc | 9 +- 8 files changed, 152 insertions(+), 102 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl index f31b4045c5..674bc64c91 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl @@ -91,11 +91,11 @@ __kernel void Conv2D_H1W1C1(__read_only image2d_t input, __write_only image2d_t out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0)); } -#ifndef EXCEDD_MAX_IMAGE2D_WIDTH - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); -#else - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); -#endif + if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); + } else { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + } } __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, @@ -172,17 +172,17 @@ __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0)); } -#ifndef EXCEDD_MAX_IMAGE2D_WIDTH - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); - } // end if (oh1 < OH) -#else - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); - } // end (oh1 < OH) -#endif + if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); + } // end if (oh1 < OH) + } else { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); + } // end (oh1 < OH) + } } __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, @@ -283,27 +283,29 @@ __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1)); } -#ifndef EXCEDD_MAX_IMAGE2D_WIDTH - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); - } // end if (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); + if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) -#else - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); - } // end (oh1 < OH) - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); - } // end if (oh1 < OH) -#endif + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); + } // end if (oh1 < OH) + } // end if (co_slice1 < CO_SLICES) + } else { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); + } // end (oh1 < OH) + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); + } // end if (oh1 < OH) + } // end if (co_slice1 < CO_SLICES) + } } __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, @@ -454,35 +456,37 @@ __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); } -#ifndef EXCEDD_MAX_IMAGE2D_WIDTH - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); - } // end if (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); + if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) -#else - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); - } // end (oh1 < OH) - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); - } // end if (oh1 < OH) -#endif + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); + } // end if (oh1 < OH) + } // end if (co_slice1 < CO_SLICES) + } else { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); + } // end (oh1 < OH) + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); + } // end if (oh1 < OH) + } // end if (co_slice1 < CO_SLICES) + } } __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2d_t output, @@ -640,33 +644,35 @@ __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2 out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); } -#ifndef EXCEDD_MAX_IMAGE2D_WIDTH - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); - } // end if (oh1 < OH) - if (co_slice1 < CO_SLICES) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); + if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0); if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); - WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0); } // end if (oh1 < OH) - } // end if (co_slice1 < CO_SLICES) -#else - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); - WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); - WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); - } // end (oh1 < OH) - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); - WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); - if (oh1 < OH) { - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); - WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); - } // end if (oh1 < OH) -#endif + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1); + } // end if (oh1 < OH) + } // end if (co_slice1 < CO_SLICES) + } else { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0); + WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0); + WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0); + } // end (oh1 < OH) + if (co_slice1 < CO_SLICES) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1); + WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1); + if (oh1 < OH) { + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1); + WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1); + } // end if (oh1 < OH) + } // end if (co_slice1 < CO_SLICES) + } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc index 779322c179..723f720993 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc @@ -144,9 +144,7 @@ void Conv2DOpenCLKernel::BuildKernel() { kernel_name << "_Img"; } ocl_runtime_->LoadSource(program_name, GetActDefines() + conv2d_source); - std::string build_option = - (OW_ * CO_SLICES_ <= ocl_runtime_->GetMaxImage2DWidth()) ? "" : " -DEXCEDD_MAX_IMAGE2D_WIDTH"; - ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str(), {build_option}); + ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str()); } void Conv2DOpenCLKernel::SetBlockSize() { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc index 57ce27f397..8fe7af3d35 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc @@ -379,7 +379,7 @@ std::string FusionEltwiseOpenCLKernel::CodegenCore(FusionEltwiseParameter *param code << cl_prefix << "FLT4 " << exp0 << " = exp(" + var0 + ");\n"; code << cl_prefix << "FLT4 " << exp1 << " = exp(-" + var0 + ");\n"; code << cl_prefix << "FLT4 " << out_name << " = (" << exp0 << " - " << exp1 << ") / (" << exp0 << " + " << exp1 - << "));\n"; + << ");\n"; } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.cc index c8d772570c..e05db42848 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.cc @@ -143,11 +143,11 @@ int StridedSliceOpenCLKernel::InitConstArgs() { // avoid begin is out of range begin_.s[i] = std::clamp(begin_.s[i], 0, input_shape_.s[i] - 1); // end is negative - if (end_.s[i] < 0) { + if (end_.s[i] <= 0) { end_.s[i] += input_shape_.s[i]; } // avoid end is out of range - end_.s[i] = std::clamp(end_.s[i], -1, input_shape_.s[i]); + end_.s[i] = std::clamp(end_.s[i], 0, input_shape_.s[i]); // check stride begin end if (stride_.s[i] > 0) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc index dc3d53dcb1..1ae4303451 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc @@ -42,8 +42,6 @@ int ToFormatOpenCLKernel::CheckSpecs() { MS_LOG(ERROR) << "Unsupported data type " << data_type; return RET_ERROR; } - auto parameter = reinterpret_cast(op_parameter_); - out_mem_type_ = parameter->out_mem_type; return RET_OK; } @@ -103,8 +101,10 @@ int ToFormatOpenCLKernel::Run() { } int ToFormatOpenCLKernel::InferShape() { - out_tensors_[0]->set_shape(in_tensors_[0]->shape()); - op_parameter_->infer_flag_ = false; + if (!op_parameter_->infer_flag_) { + op_parameter_->infer_flag_ = true; + out_tensors_.front()->set_shape(in_tensors_.front()->shape()); + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h index a84f10a363..c43ce8214f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h @@ -25,7 +25,11 @@ namespace mindspore::kernel { class ToFormatOpenCLKernel : public OpenCLKernel { public: - using OpenCLKernel::OpenCLKernel; + ToFormatOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx) + : OpenCLKernel(parameter, inputs, outputs, ctx) { + out_mem_type_ = reinterpret_cast(op_parameter_)->out_mem_type; + } ~ToFormatOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc index 5a97caf666..2079b65bc4 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc @@ -135,6 +135,8 @@ std::vector RemoveDuplicationsButKeepOrder(const std::vector &vec) { void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) { MS_ASSERT(a); MS_ASSERT(b); + MS_ASSERT(a->op_parameter()->infer_flag_); + MS_ASSERT(b->op_parameter()->infer_flag_); if (remove_a) { // pred->tensor0->a->tensor1->b: remove a tensor1 // update pred out_kernels: a.in_kernels.out_kernels.replace(a,b) for (auto *pred : a->in_kernels()) { @@ -231,6 +233,9 @@ void TryMergePadXxx(LiteKernel *node, std::set *removed_set, std:: } LiteKernel *pad = node->in_kernels().front(); MS_ASSERT(pad); + if (!pad->op_parameter()->infer_flag_) { + return; + } if (pad->in_tensors().front()->shape().size() != 4) { return; } @@ -264,6 +269,10 @@ void TryMergeConvReshape(LiteKernel *reshape, std::set *removed_se // group must be 1 LiteKernel *conv = reshape->in_kernels().front(); MS_ASSERT(conv); + + if (!conv->op_parameter()->infer_flag_) { + return; + } auto *param = reinterpret_cast(reinterpret_cast(conv)->GetParameter()); MS_ASSERT(param); if (param->group_ != 1) { @@ -286,6 +295,9 @@ void TryMergeFcReshape(LiteKernel *reshape, std::set *removed_set, bool NC_N11C_flag = NC_N11C(reshape); if (NC_N11C_flag || N11C_NC(reshape)) { LiteKernel *fc = reshape->in_kernels().front(); + if (!fc->op_parameter()->infer_flag_) { + return; + } MS_ASSERT(fc); MergeRemoveB(fc, reshape, removed_set); MS_LOG(DEBUG) << "Merge FullConnection and Reshape" + (NC_N11C_flag ? std::string("(NC->N11C)") : "(N11C->NC)") + @@ -303,6 +315,9 @@ void TryMergeReshapeFc(LiteKernel *fc, std::set *removed_set, std: } LiteKernel *reshape = fc->in_kernels().front(); MS_ASSERT(reshape); + if (!reshape->op_parameter()->infer_flag_) { + return; + } bool NC11_NC_flag = NC11_NC(reshape); if (NC11_NC_flag || NC_N11C(reshape)) { MergeRemoveA(reshape, fc, removed_set); @@ -317,6 +332,9 @@ void TryMergeArithmeticAct(LiteKernel *act, std::set *removed_set) MS_ASSERT(removed_set); LiteKernel *arithmetic = act->in_kernels().front(); MS_ASSERT(arithmetic); + if (!arithmetic->op_parameter()->infer_flag_) { + return; + } auto *arithmetic_param = reinterpret_cast(reinterpret_cast(arithmetic)->GetParameter()); auto *act_param = reinterpret_cast(reinterpret_cast(act)->GetParameter()); @@ -339,6 +357,10 @@ void TryMergeXxxActivation(LiteKernel *act, std::set *removed_set) MS_ASSERT(removed_set); auto *act_param = reinterpret_cast(reinterpret_cast(act)->GetParameter()); LiteKernel *node = act->in_kernels().front(); + MS_ASSERT(node); + if (!node->op_parameter()->infer_flag_) { + return; + } auto *param = reinterpret_cast(reinterpret_cast(node)->GetParameter()); MS_ASSERT(param); @@ -379,6 +401,9 @@ void TryMergeConvPReLU(LiteKernel *prelu, std::set *removed_set, s } LiteKernel *conv = prelu->in_kernels().front(); MS_ASSERT(conv); + if (!conv->op_parameter()->infer_flag_) { + return; + } if (reinterpret_cast(conv)->use_winograd_) { return; } @@ -475,6 +500,9 @@ void TryMergeDeconvScale(LiteKernel *scale, std::set *removed_set, } LiteKernel *deconv = scale->in_kernels().front(); MS_ASSERT(deconv); + if (!deconv->op_parameter()->infer_flag_) { + return; + } // check act_type_ auto *deconv_param = reinterpret_cast(reinterpret_cast(deconv)->GetParameter()); @@ -546,6 +574,9 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol // Eltwise + Eltwise int TryMergeEltwiseEltwise(LiteKernel *node, std::set *removed_set, std::vector *nodes) { + if (!node->op_parameter()->infer_flag_) { + return RET_ERROR; + } MS_ASSERT(node); MS_ASSERT(nodes); MS_ASSERT(removed_set); @@ -560,6 +591,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set *removed_set std::map pred_params; for (LiteKernel *pred : preds) { MS_ASSERT(pred); + if (!pred->op_parameter()->infer_flag_) { + continue; + } if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) { auto *tensor = pred->out_tensors().front(); MS_ASSERT(pred->out_kernels().front() == node); @@ -589,6 +623,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set *removed_set } void DoSpecificFusion(LiteKernel *node, std::set *removed_set, std::vector *nodes) { + if (!node->op_parameter()->infer_flag_) { + return; + } switch (node->Type()) { case schema::PrimitiveType_Conv2DFusion: case schema::PrimitiveType_Conv2dTransposeFusion: { diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc index 04bbe7d922..0d9d78b046 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc @@ -100,6 +100,11 @@ void OpenCLSubGraph::ReplaceOutTensorAndKernelToConvert(const lite::Tensor *in_t iv->set_out_tensors(tensors); in_convert_op->AddInKernel(iv); } + if (in_convert_op->in_kernels().empty()) { + in_convert_op->op_parameter()->infer_flag_ = true; + } else { + in_convert_op->op_parameter()->infer_flag_ = in_opencl_op->in_kernels().front()->op_parameter()->infer_flag_; + } } } @@ -143,7 +148,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector &in_tensors, return RET_ERROR; } parameter->op_parameter.type_ = PRIM_TO_FORMAT; - parameter->op_parameter.infer_flag_ = false; + parameter->op_parameter.infer_flag_ = false; // infer_flag_ is set in ReplaceOutTensorAndKernelToConvert() parameter->out_mem_type = mem_type; out_parameters->emplace_back(parameter); LiteKernel *in_convert_op = nullptr; @@ -155,7 +160,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector &in_tensors, {new_tensor}, {in_tensor}, reinterpret_cast(parameter), context_, desc); } MS_ASSERT(in_convert_op); - if (in_convert_op == nullptr) { + if (in_convert_op == nullptr || reinterpret_cast(in_convert_op)->CheckSpecs() != RET_OK) { MS_LOG(ERROR) << "OpenCLSubGraph create op failed!"; delete new_tensor; new_tensor = nullptr;