Browse Source

!15013 【MS】【LITE】【GPU】fix some opencl bugs

From: @wangdongxu6
Reviewed-by: @ddwsky,@zhanghaibo5
Signed-off-by: @ddwsky
pull/15013/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
fa20f96b32
8 changed files with 152 additions and 102 deletions
  1. +95
    -89
      mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl
  2. +1
    -3
      mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc
  3. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc
  4. +2
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.cc
  5. +4
    -4
      mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc
  6. +5
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h
  7. +37
    -0
      mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
  8. +7
    -2
      mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc

+ 95
- 89
mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl View File

@@ -91,11 +91,11 @@ __kernel void Conv2D_H1W1C1(__read_only image2d_t input, __write_only image2d_t
out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));
}

#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
#endif
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
}
}

__kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,
@@ -172,17 +172,17 @@ __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t
out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));
}

#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH)
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
#endif
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH)
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
}
}

__kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,
@@ -283,27 +283,29 @@ __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t
out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));
}

#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
} // end if (oh1 < OH)
#endif
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}
}

__kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,
@@ -454,35 +456,37 @@ __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t
out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));
}

#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);
} // end if (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
#endif
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}
}

__kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2d_t output,
@@ -640,33 +644,35 @@ __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2
out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));
}

#ifndef EXCEDD_MAX_IMAGE2D_WIDTH
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);
} // end if (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
#else
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
#endif
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}
}

+ 1
- 3
mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc View File

@@ -144,9 +144,7 @@ void Conv2DOpenCLKernel::BuildKernel() {
kernel_name << "_Img";
}
ocl_runtime_->LoadSource(program_name, GetActDefines() + conv2d_source);
std::string build_option =
(OW_ * CO_SLICES_ <= ocl_runtime_->GetMaxImage2DWidth()) ? "" : " -DEXCEDD_MAX_IMAGE2D_WIDTH";
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str(), {build_option});
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str());
}

void Conv2DOpenCLKernel::SetBlockSize() {


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc View File

@@ -379,7 +379,7 @@ std::string FusionEltwiseOpenCLKernel::CodegenCore(FusionEltwiseParameter *param
code << cl_prefix << "FLT4 " << exp0 << " = exp(" + var0 + ");\n";
code << cl_prefix << "FLT4 " << exp1 << " = exp(-" + var0 + ");\n";
code << cl_prefix << "FLT4 " << out_name << " = (" << exp0 << " - " << exp1 << ") / (" << exp0 << " + " << exp1
<< "));\n";
<< ");\n";
}
}



+ 2
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.cc View File

@@ -143,11 +143,11 @@ int StridedSliceOpenCLKernel::InitConstArgs() {
// avoid begin is out of range
begin_.s[i] = std::clamp(begin_.s[i], 0, input_shape_.s[i] - 1);
// end is negative
if (end_.s[i] < 0) {
if (end_.s[i] <= 0) {
end_.s[i] += input_shape_.s[i];
}
// avoid end is out of range
end_.s[i] = std::clamp(end_.s[i], -1, input_shape_.s[i]);
end_.s[i] = std::clamp(end_.s[i], 0, input_shape_.s[i]);

// check stride begin end
if (stride_.s[i] > 0) {


+ 4
- 4
mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc View File

@@ -42,8 +42,6 @@ int ToFormatOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << "Unsupported data type " << data_type;
return RET_ERROR;
}
auto parameter = reinterpret_cast<OpenCLToFormatParameter *>(op_parameter_);
out_mem_type_ = parameter->out_mem_type;
return RET_OK;
}

@@ -103,8 +101,10 @@ int ToFormatOpenCLKernel::Run() {
}

int ToFormatOpenCLKernel::InferShape() {
out_tensors_[0]->set_shape(in_tensors_[0]->shape());
op_parameter_->infer_flag_ = false;
if (!op_parameter_->infer_flag_) {
op_parameter_->infer_flag_ = true;
out_tensors_.front()->set_shape(in_tensors_.front()->shape());
}
return RET_OK;
}



+ 5
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h View File

@@ -25,7 +25,11 @@
namespace mindspore::kernel {
class ToFormatOpenCLKernel : public OpenCLKernel {
public:
using OpenCLKernel::OpenCLKernel;
ToFormatOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: OpenCLKernel(parameter, inputs, outputs, ctx) {
out_mem_type_ = reinterpret_cast<OpenCLToFormatParameter *>(op_parameter_)->out_mem_type;
}
~ToFormatOpenCLKernel() override = default;

int Run() override;


+ 37
- 0
mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc View File

@@ -135,6 +135,8 @@ std::vector<T *> RemoveDuplicationsButKeepOrder(const std::vector<T *> &vec) {
void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) {
MS_ASSERT(a);
MS_ASSERT(b);
MS_ASSERT(a->op_parameter()->infer_flag_);
MS_ASSERT(b->op_parameter()->infer_flag_);
if (remove_a) { // pred->tensor0->a->tensor1->b: remove a tensor1
// update pred out_kernels: a.in_kernels.out_kernels.replace(a,b)
for (auto *pred : a->in_kernels()) {
@@ -231,6 +233,9 @@ void TryMergePadXxx(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::
}
LiteKernel *pad = node->in_kernels().front();
MS_ASSERT(pad);
if (!pad->op_parameter()->infer_flag_) {
return;
}
if (pad->in_tensors().front()->shape().size() != 4) {
return;
}
@@ -264,6 +269,10 @@ void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_se
// group must be 1
LiteKernel *conv = reshape->in_kernels().front();
MS_ASSERT(conv);

if (!conv->op_parameter()->infer_flag_) {
return;
}
auto *param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(conv)->GetParameter());
MS_ASSERT(param);
if (param->group_ != 1) {
@@ -286,6 +295,9 @@ void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set,
bool NC_N11C_flag = NC_N11C(reshape);
if (NC_N11C_flag || N11C_NC(reshape)) {
LiteKernel *fc = reshape->in_kernels().front();
if (!fc->op_parameter()->infer_flag_) {
return;
}
MS_ASSERT(fc);
MergeRemoveB(fc, reshape, removed_set);
MS_LOG(DEBUG) << "Merge FullConnection and Reshape" + (NC_N11C_flag ? std::string("(NC->N11C)") : "(N11C->NC)") +
@@ -303,6 +315,9 @@ void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set, std:
}
LiteKernel *reshape = fc->in_kernels().front();
MS_ASSERT(reshape);
if (!reshape->op_parameter()->infer_flag_) {
return;
}
bool NC11_NC_flag = NC11_NC(reshape);
if (NC11_NC_flag || NC_N11C(reshape)) {
MergeRemoveA(reshape, fc, removed_set);
@@ -317,6 +332,9 @@ void TryMergeArithmeticAct(LiteKernel *act, std::set<LiteKernel *> *removed_set)
MS_ASSERT(removed_set);
LiteKernel *arithmetic = act->in_kernels().front();
MS_ASSERT(arithmetic);
if (!arithmetic->op_parameter()->infer_flag_) {
return;
}
auto *arithmetic_param =
reinterpret_cast<ArithmeticParameter *>(reinterpret_cast<OpenCLKernel *>(arithmetic)->GetParameter());
auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter());
@@ -339,6 +357,10 @@ void TryMergeXxxActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set)
MS_ASSERT(removed_set);
auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter());
LiteKernel *node = act->in_kernels().front();
MS_ASSERT(node);
if (!node->op_parameter()->infer_flag_) {
return;
}
auto *param = reinterpret_cast<ParamType *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter());
MS_ASSERT(param);

@@ -379,6 +401,9 @@ void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set, s
}
LiteKernel *conv = prelu->in_kernels().front();
MS_ASSERT(conv);
if (!conv->op_parameter()->infer_flag_) {
return;
}
if (reinterpret_cast<Conv2DOpenCLKernel *>(conv)->use_winograd_) {
return;
}
@@ -475,6 +500,9 @@ void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set,
}
LiteKernel *deconv = scale->in_kernels().front();
MS_ASSERT(deconv);
if (!deconv->op_parameter()->infer_flag_) {
return;
}

// check act_type_
auto *deconv_param = reinterpret_cast<ConvParameter *>(reinterpret_cast<OpenCLKernel *>(deconv)->GetParameter());
@@ -546,6 +574,9 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol

// Eltwise + Eltwise
int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
if (!node->op_parameter()->infer_flag_) {
return RET_ERROR;
}
MS_ASSERT(node);
MS_ASSERT(nodes);
MS_ASSERT(removed_set);
@@ -560,6 +591,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set
std::map<lite::Tensor *, FusionEltwiseParameter *> pred_params;
for (LiteKernel *pred : preds) {
MS_ASSERT(pred);
if (!pred->op_parameter()->infer_flag_) {
continue;
}
if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) {
auto *tensor = pred->out_tensors().front();
MS_ASSERT(pred->out_kernels().front() == node);
@@ -589,6 +623,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set
}

void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
if (!node->op_parameter()->infer_flag_) {
return;
}
switch (node->Type()) {
case schema::PrimitiveType_Conv2DFusion:
case schema::PrimitiveType_Conv2dTransposeFusion: {


+ 7
- 2
mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc View File

@@ -100,6 +100,11 @@ void OpenCLSubGraph::ReplaceOutTensorAndKernelToConvert(const lite::Tensor *in_t
iv->set_out_tensors(tensors);
in_convert_op->AddInKernel(iv);
}
if (in_convert_op->in_kernels().empty()) {
in_convert_op->op_parameter()->infer_flag_ = true;
} else {
in_convert_op->op_parameter()->infer_flag_ = in_opencl_op->in_kernels().front()->op_parameter()->infer_flag_;
}
}
}

@@ -143,7 +148,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors,
return RET_ERROR;
}
parameter->op_parameter.type_ = PRIM_TO_FORMAT;
parameter->op_parameter.infer_flag_ = false;
parameter->op_parameter.infer_flag_ = false; // infer_flag_ is set in ReplaceOutTensorAndKernelToConvert()
parameter->out_mem_type = mem_type;
out_parameters->emplace_back(parameter);
LiteKernel *in_convert_op = nullptr;
@@ -155,7 +160,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors,
{new_tensor}, {in_tensor}, reinterpret_cast<OpParameter *>(parameter), context_, desc);
}
MS_ASSERT(in_convert_op);
if (in_convert_op == nullptr) {
if (in_convert_op == nullptr || reinterpret_cast<ToFormatOpenCLKernel *>(in_convert_op)->CheckSpecs() != RET_OK) {
MS_LOG(ERROR) << "OpenCLSubGraph create op failed!";
delete new_tensor;
new_tensor = nullptr;


Loading…
Cancel
Save