diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc index ca8189f7ed..5304690109 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc @@ -18,6 +18,7 @@ #include "src/runtime/kernel/opencl/utils.h" #include "include/errorcode.h" #include "nnacl/fp32/activation_fp32.h" +#include "nnacl/scale.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; @@ -60,12 +61,18 @@ std::pair CheckSupportOrCreateParam( param = reinterpret_cast(eltwise->GetParameter()); eltwise->ClearParameter(); } - } else if (IsArithmetic(node_type)) { - auto act_type = - static_cast(reinterpret_cast(op_parameter)->activation_type_); + } else if (IsArithmetic(node_type) || node_type == schema::PrimitiveType_Scale) { + auto *arith_param = reinterpret_cast(op_parameter); + auto *scale_param = reinterpret_cast(op_parameter); + auto act_type = static_cast( + node_type == schema::PrimitiveType_Scale ? scale_param->activation_type_ : arith_param->activation_type_); EltwiseOperator act_operator = Activation2Operator(act_type); - support = - node->in_tensors().size() == 2 && SupportedOperators.count(operator_) && SupportedOperators.count(act_operator); + support = SupportedOperators.count(operator_) && SupportedOperators.count(act_operator); + if (node_type == schema::PrimitiveType_Scale) { + support = support && node->in_tensors().size() == 3 && scale_param->axis_ == -1; + } else { + support = support && (node->in_tensors().size() == 2); + } if (create_param) { param = new (std::nothrow) FusionEltwiseParameter(operator_, node->name(), node->in_tensors(), replace_map); MS_ASSERT(param); @@ -83,12 +90,6 @@ std::pair CheckSupportOrCreateParam( param = new (std::nothrow) FusionEltwiseParameter(operator_, node->name(), node->in_tensors(), replace_map); MS_ASSERT(param); } - } else if (node_type == schema::PrimitiveType_Scale) { - support = node->in_tensors().size() == 3 && SupportedOperators.count(operator_); - if (create_param) { - param = new (std::nothrow) FusionEltwiseParameter(operator_, node->name(), node->in_tensors(), replace_map); - MS_ASSERT(param); - } } else if (node_type == schema::PrimitiveType_Activation) { auto act_type = static_cast(reinterpret_cast(op_parameter)->type_); EltwiseOperator act_operator = Activation2Operator(act_type); @@ -141,15 +142,11 @@ bool IsEltwiseAndOperatorSupported(LiteKernel *node) { } int FusionEltwiseOpenCLKernel::Prepare() { - static std::set code_map; std::string source = Codegen(); - code_map.insert(source); - - std::string program_name = "FusionEltwise" + std::to_string(code_map.size()); + std::string program_name = "FusionEltwise\n" + source; std::string kernel_name = "FusionEltwise"; ocl_runtime_->LoadSource(program_name, source); ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name); - InitWeights(); SetGlobalLocal(); SetConstArgs(); @@ -390,20 +387,22 @@ std::string FusionEltwiseOpenCLKernel::CodegenCore(FusionEltwiseParameter *param std::string FusionEltwiseOpenCLKernel::GetFormatVarName(std::string name) { if (var_names_.count(name)) { - return name; - } - if (name.empty()) { - name = "_var_" + std::to_string(var_names_.size()); + return simplify_var_name_ ? var_names_[name] : name; } else { - char c = name.front(); - if (c != '_' && !std::isalpha(c)) { - name = '_' + name; + if (name.empty()) { + name = "_var_" + std::to_string(var_names_.size()); + } else { + char c = name.front(); + if (c != '_' && !std::isalpha(c)) { + name = '_' + name; + } + std::replace_if( + name.begin(), name.end(), [](char c) { return !std::isalnum(c); }, '_'); } - std::replace_if( - name.begin(), name.end(), [](char c) { return !std::isalnum(c); }, '_'); + auto new_name = "tmp" + std::to_string(var_names_.size()); + var_names_.emplace(name, new_name); + return simplify_var_name_ ? new_name : name; } - var_names_.insert(name); - return name; } int FusionEltwiseOpenCLKernel::GetTensorIdx(lite::Tensor *in_tensor) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.h index b7ce6fed0b..6d0d3c58d0 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.h @@ -180,7 +180,8 @@ class FusionEltwiseOpenCLKernel : public OpenCLKernel { return shape.empty() || (shape.size() == 1 && shape.front() == 1); } - std::set var_names_; + std::map var_names_; // origin name -> simplified name + const bool simplify_var_name_{true}; std::vector scalar_weights_; std::vector buffer_weights_; };