Browse Source

!15013 【MS】【LITE】【GPU】fix some opencl bugs

From: @wangdongxu6
Reviewed-by: @ddwsky,@zhanghaibo5
Signed-off-by: @ddwsky
pull/15013/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
fa20f96b32
8 changed files with 152 additions and 102 deletions
  1. +95
    -89
      mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl
  2. +1
    -3
      mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc
  3. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc
  4. +2
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.cc
  5. +4
    -4
      mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc
  6. +5
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h
  7. +37
    -0
      mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
  8. +7
    -2
      mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc

+ 95
- 89
mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl View File

@@ -91,11 +91,11 @@ __kernel void Conv2D_H1W1C1(__read_only image2d_t input, __write_only image2d_t
out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0)); out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));
} }


#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
#endif
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
}
} }


__kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,
@@ -172,17 +172,17 @@ __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t
out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0)); out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));
} }


#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH)
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
#endif
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH)
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
}
} }


__kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,
@@ -283,27 +283,29 @@ __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t
out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1)); out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));
} }


#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) { if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH) } // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
} // end if (oh1 < OH)
#endif
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}
} }


__kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight, __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,
@@ -454,35 +456,37 @@ __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t
out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));
} }


#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);
} // end if (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) { if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);
} // end if (oh1 < OH) } // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
#endif
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}
} }


__kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2d_t output, __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2d_t output,
@@ -640,33 +644,35 @@ __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2
out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1)); out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));
} }


#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);
} // end if (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) { if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);
} // end if (oh1 < OH) } // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
#endif
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}
} }

+ 1
- 3
mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc View File

@@ -144,9 +144,7 @@ void Conv2DOpenCLKernel::BuildKernel() {
kernel_name << "_Img"; kernel_name << "_Img";
} }
ocl_runtime_->LoadSource(program_name, GetActDefines() + conv2d_source); ocl_runtime_->LoadSource(program_name, GetActDefines() + conv2d_source);
std::string build_option =
(OW_ * CO_SLICES_ <= ocl_runtime_->GetMaxImage2DWidth()) ? "" : " -DEXCEDD_MAX_IMAGE2D_WIDTH";
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str(), {build_option});
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str());
} }


void Conv2DOpenCLKernel::SetBlockSize() { void Conv2DOpenCLKernel::SetBlockSize() {


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc View File

@@ -379,7 +379,7 @@ std::string FusionEltwiseOpenCLKernel::CodegenCore(FusionEltwiseParameter *param
code << cl_prefix << "FLT4 " << exp0 << " = exp(" + var0 + ");\n"; code << cl_prefix << "FLT4 " << exp0 << " = exp(" + var0 + ");\n";
code << cl_prefix << "FLT4 " << exp1 << " = exp(-" + var0 + ");\n"; code << cl_prefix << "FLT4 " << exp1 << " = exp(-" + var0 + ");\n";
code << cl_prefix << "FLT4 " << out_name << " = (" << exp0 << " - " << exp1 << ") / (" << exp0 << " + " << exp1 code << cl_prefix << "FLT4 " << out_name << " = (" << exp0 << " - " << exp1 << ") / (" << exp0 << " + " << exp1
<< "));\n";
<< ");\n";
} }
} }




+ 2
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.cc View File

@@ -143,11 +143,11 @@ int StridedSliceOpenCLKernel::InitConstArgs() {
// avoid begin is out of range // avoid begin is out of range
begin_.s[i] = std::clamp(begin_.s[i], 0, input_shape_.s[i] - 1); begin_.s[i] = std::clamp(begin_.s[i], 0, input_shape_.s[i] - 1);
// end is negative // end is negative
if (end_.s[i] < 0) {
if (end_.s[i] <= 0) {
end_.s[i] += input_shape_.s[i]; end_.s[i] += input_shape_.s[i];
} }
// avoid end is out of range // avoid end is out of range
end_.s[i] = std::clamp(end_.s[i], -1, input_shape_.s[i]);
end_.s[i] = std::clamp(end_.s[i], 0, input_shape_.s[i]);


// check stride begin end // check stride begin end
if (stride_.s[i] > 0) { if (stride_.s[i] > 0) {


+ 4
- 4
mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc View File

@@ -42,8 +42,6 @@ int ToFormatOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << "Unsupported data type " << data_type; MS_LOG(ERROR) << "Unsupported data type " << data_type;
return RET_ERROR; return RET_ERROR;
} }
auto parameter = reinterpret_cast<OpenCLToFormatParameter *>(op_parameter_);
out_mem_type_ = parameter->out_mem_type;
return RET_OK; return RET_OK;
} }


@@ -103,8 +101,10 @@ int ToFormatOpenCLKernel::Run() {
} }


int ToFormatOpenCLKernel::InferShape() { int ToFormatOpenCLKernel::InferShape() {
out_tensors_[0]->set_shape(in_tensors_[0]->shape());
op_parameter_->infer_flag_ = false;
if (!op_parameter_->infer_flag_) {
op_parameter_->infer_flag_ = true;
out_tensors_.front()->set_shape(in_tensors_.front()->shape());
}
return RET_OK; return RET_OK;
} }




+ 5
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h View File

@@ -25,7 +25,11 @@
namespace mindspore::kernel { namespace mindspore::kernel {
class ToFormatOpenCLKernel : public OpenCLKernel { class ToFormatOpenCLKernel : public OpenCLKernel {
public: public:
using OpenCLKernel::OpenCLKernel;
ToFormatOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: OpenCLKernel(parameter, inputs, outputs, ctx) {
out_mem_type_ = reinterpret_cast<OpenCLToFormatParameter *>(op_parameter_)->out_mem_type;
}
~ToFormatOpenCLKernel() override = default; ~ToFormatOpenCLKernel() override = default;


int Run() override; int Run() override;


+ 37
- 0
mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc View File

@@ -135,6 +135,8 @@ std::vector<T *> RemoveDuplicationsButKeepOrder(const std::vector<T *> &vec) {
void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) { void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) {
MS_ASSERT(a); MS_ASSERT(a);
MS_ASSERT(b); MS_ASSERT(b);
MS_ASSERT(a->op_parameter()->infer_flag_);
MS_ASSERT(b->op_parameter()->infer_flag_);
if (remove_a) { // pred->tensor0->a->tensor1->b: remove a tensor1 if (remove_a) { // pred->tensor0->a->tensor1->b: remove a tensor1
// update pred out_kernels: a.in_kernels.out_kernels.replace(a,b) // update pred out_kernels: a.in_kernels.out_kernels.replace(a,b)
for (auto *pred : a->in_kernels()) { for (auto *pred : a->in_kernels()) {
@@ -231,6 +233,9 @@ void TryMergePadXxx(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::
} }
LiteKernel *pad = node->in_kernels().front(); LiteKernel *pad = node->in_kernels().front();
MS_ASSERT(pad); MS_ASSERT(pad);
if (!pad->op_parameter()->infer_flag_) {
return;
}
if (pad->in_tensors().front()->shape().size() != 4) { if (pad->in_tensors().front()->shape().size() != 4) {
return; return;
} }
@@ -264,6 +269,10 @@ void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_se
// group must be 1 // group must be 1
LiteKernel *conv = reshape->in_kernels().front(); LiteKernel *conv = reshape->in_kernels().front();
MS_ASSERT(conv); MS_ASSERT(conv);

if (!conv->op_parameter()->infer_flag_) {
return;
}
auto *param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(conv)->GetParameter()); auto *param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(conv)->GetParameter());
MS_ASSERT(param); MS_ASSERT(param);
if (param->group_ != 1) { if (param->group_ != 1) {
@@ -286,6 +295,9 @@ void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set,
bool NC_N11C_flag = NC_N11C(reshape); bool NC_N11C_flag = NC_N11C(reshape);
if (NC_N11C_flag || N11C_NC(reshape)) { if (NC_N11C_flag || N11C_NC(reshape)) {
LiteKernel *fc = reshape->in_kernels().front(); LiteKernel *fc = reshape->in_kernels().front();
if (!fc->op_parameter()->infer_flag_) {
return;
}
MS_ASSERT(fc); MS_ASSERT(fc);
MergeRemoveB(fc, reshape, removed_set); MergeRemoveB(fc, reshape, removed_set);
MS_LOG(DEBUG) << "Merge FullConnection and Reshape" + (NC_N11C_flag ? std::string("(NC->N11C)") : "(N11C->NC)") + MS_LOG(DEBUG) << "Merge FullConnection and Reshape" + (NC_N11C_flag ? std::string("(NC->N11C)") : "(N11C->NC)") +
@@ -303,6 +315,9 @@ void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set, std:
} }
LiteKernel *reshape = fc->in_kernels().front(); LiteKernel *reshape = fc->in_kernels().front();
MS_ASSERT(reshape); MS_ASSERT(reshape);
if (!reshape->op_parameter()->infer_flag_) {
return;
}
bool NC11_NC_flag = NC11_NC(reshape); bool NC11_NC_flag = NC11_NC(reshape);
if (NC11_NC_flag || NC_N11C(reshape)) { if (NC11_NC_flag || NC_N11C(reshape)) {
MergeRemoveA(reshape, fc, removed_set); MergeRemoveA(reshape, fc, removed_set);
@@ -317,6 +332,9 @@ void TryMergeArithmeticAct(LiteKernel *act, std::set<LiteKernel *> *removed_set)
MS_ASSERT(removed_set); MS_ASSERT(removed_set);
LiteKernel *arithmetic = act->in_kernels().front(); LiteKernel *arithmetic = act->in_kernels().front();
MS_ASSERT(arithmetic); MS_ASSERT(arithmetic);
if (!arithmetic->op_parameter()->infer_flag_) {
return;
}
auto *arithmetic_param = auto *arithmetic_param =
reinterpret_cast<ArithmeticParameter *>(reinterpret_cast<OpenCLKernel *>(arithmetic)->GetParameter()); reinterpret_cast<ArithmeticParameter *>(reinterpret_cast<OpenCLKernel *>(arithmetic)->GetParameter());
auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter()); auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter());
@@ -339,6 +357,10 @@ void TryMergeXxxActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set)
MS_ASSERT(removed_set); MS_ASSERT(removed_set);
auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter()); auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter());
LiteKernel *node = act->in_kernels().front(); LiteKernel *node = act->in_kernels().front();
MS_ASSERT(node);
if (!node->op_parameter()->infer_flag_) {
return;
}
auto *param = reinterpret_cast<ParamType *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter()); auto *param = reinterpret_cast<ParamType *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter());
MS_ASSERT(param); MS_ASSERT(param);


@@ -379,6 +401,9 @@ void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set, s
} }
LiteKernel *conv = prelu->in_kernels().front(); LiteKernel *conv = prelu->in_kernels().front();
MS_ASSERT(conv); MS_ASSERT(conv);
if (!conv->op_parameter()->infer_flag_) {
return;
}
if (reinterpret_cast<Conv2DOpenCLKernel *>(conv)->use_winograd_) { if (reinterpret_cast<Conv2DOpenCLKernel *>(conv)->use_winograd_) {
return; return;
} }
@@ -475,6 +500,9 @@ void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set,
} }
LiteKernel *deconv = scale->in_kernels().front(); LiteKernel *deconv = scale->in_kernels().front();
MS_ASSERT(deconv); MS_ASSERT(deconv);
if (!deconv->op_parameter()->infer_flag_) {
return;
}


// check act_type_ // check act_type_
auto *deconv_param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(deconv)->GetParameter()); auto *deconv_param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(deconv)->GetParameter());
@@ -546,6 +574,9 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol


// Eltwise + Eltwise // Eltwise + Eltwise
int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) { int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
if (!node->op_parameter()->infer_flag_) {
return RET_ERROR;
}
MS_ASSERT(node); MS_ASSERT(node);
MS_ASSERT(nodes); MS_ASSERT(nodes);
MS_ASSERT(removed_set); MS_ASSERT(removed_set);
@@ -560,6 +591,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set
std::map<lite::Tensor *, FusionEltwiseParameter *> pred_params; std::map<lite::Tensor *, FusionEltwiseParameter *> pred_params;
for (LiteKernel *pred : preds) { for (LiteKernel *pred : preds) {
MS_ASSERT(pred); MS_ASSERT(pred);
if (!pred->op_parameter()->infer_flag_) {
continue;
}
if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) { if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) {
auto *tensor = pred->out_tensors().front(); auto *tensor = pred->out_tensors().front();
MS_ASSERT(pred->out_kernels().front() == node); MS_ASSERT(pred->out_kernels().front() == node);
@@ -589,6 +623,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set
} }


void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) { void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
if (!node->op_parameter()->infer_flag_) {
return;
}
switch (node->Type()) { switch (node->Type()) {
case schema::PrimitiveType_Conv2DFusion: case schema::PrimitiveType_Conv2DFusion:
case schema::PrimitiveType_Conv2dTransposeFusion: { case schema::PrimitiveType_Conv2dTransposeFusion: {


+ 7
- 2
mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc View File

@@ -100,6 +100,11 @@ void OpenCLSubGraph::ReplaceOutTensorAndKernelToConvert(const lite::Tensor *in_t
iv->set_out_tensors(tensors); iv->set_out_tensors(tensors);
in_convert_op->AddInKernel(iv); in_convert_op->AddInKernel(iv);
} }
if (in_convert_op->in_kernels().empty()) {
in_convert_op->op_parameter()->infer_flag_ = true;
} else {
in_convert_op->op_parameter()->infer_flag_ = in_opencl_op->in_kernels().front()->op_parameter()->infer_flag_;
}
} }
} }


@@ -143,7 +148,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors,
return RET_ERROR; return RET_ERROR;
} }
parameter->op_parameter.type_ = PRIM_TO_FORMAT; parameter->op_parameter.type_ = PRIM_TO_FORMAT;
parameter->op_parameter.infer_flag_ = false;
parameter->op_parameter.infer_flag_ = false; // infer_flag_ is set in ReplaceOutTensorAndKernelToConvert()
parameter->out_mem_type = mem_type; parameter->out_mem_type = mem_type;
out_parameters->emplace_back(parameter); out_parameters->emplace_back(parameter);
LiteKernel *in_convert_op = nullptr; LiteKernel *in_convert_op = nullptr;
@@ -155,7 +160,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors,
{new_tensor}, {in_tensor}, reinterpret_cast<OpParameter *>(parameter), context_, desc); {new_tensor}, {in_tensor}, reinterpret_cast<OpParameter *>(parameter), context_, desc);
} }
MS_ASSERT(in_convert_op); MS_ASSERT(in_convert_op);
if (in_convert_op == nullptr) {
if (in_convert_op == nullptr || reinterpret_cast<ToFormatOpenCLKernel *>(in_convert_op)->CheckSpecs() != RET_OK) {
MS_LOG(ERROR) << "OpenCLSubGraph create op failed!"; MS_LOG(ERROR) << "OpenCLSubGraph create op failed!";
delete new_tensor; delete new_tensor;
new_tensor = nullptr; new_tensor = nullptr;


Loading…
Cancel
Save