diff --git a/mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc
index 43c4bac9c5..6d9846d204 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc
@@ -76,18 +76,12 @@ int SliceCPUKernel::SliceParallelRun(int thread_id) {
 }
 
 int SliceCPUKernel::Run() {
-  auto ret = PreProcess();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "PreProcess fail!ret: " << ret;
-    return ret;
-  }
-
   if (param_->size_[1] < op_parameter_->thread_num_) {
     DoSliceNoParallel(in_tensors_.at(0)->data_c(), out_tensors_.at(0)->data_c(), param_,
                       lite::DataTypeSize(in_tensors_.at(0)->data_type()));
     return RET_OK;
   }
-  ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
+  auto ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "slice launch fail!ret: " << ret;
     return RET_ERROR;
@@ -96,6 +90,5 @@ int SliceCPUKernel::Run() {
 }
 
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
-REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
 }  // namespace mindspore::kernel
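The float16 registration removed above is replaced by the dedicated kernel added below, which converts an fp32 input to fp16 once (in Init) and then slices fp16 data directly. A minimal sketch of that conversion step, assuming a plain scalar loop (nnacl's Float32ToFloat16 may well be NEON-accelerated; `_Float16` availability is an arm64 clang/gcc assumption):

```cpp
#include <cstddef>

// Scalar sketch of the fp32 -> fp16 narrowing done once in Init(); only the
// semantics are illustrated here, not nnacl's actual implementation. On
// targets without _Float16, a uint16_t bit-packing routine would be needed.
void Float32ToFloat16Sketch(const float *in, _Float16 *out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = static_cast<_Float16>(in[i]);  // round-to-nearest narrowing cast
  }
}
```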
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.cc
new file mode 100644
index 0000000000..f878bd13f8
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.cc
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/fp16/slice_fp16.h"
+#include "src/kernel_registry.h"
+#include "nnacl/base/slice_base.h"
+#include "nnacl/fp16/cast_fp16.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_SliceFusion;
+
+namespace mindspore::kernel {
+int SliceFp16Launch(void *cdata, int task_id) {
+  if (cdata == nullptr) {
+    MS_LOG(ERROR) << "Input cdata is nullptr!";
+    return RET_ERROR;
+  }
+  auto kernel = reinterpret_cast<SliceFp16CPUKernel *>(cdata);
+  return kernel->SliceFp16ParallelRun(task_id);
+}
+
+SliceFp16CPUKernel::~SliceFp16CPUKernel() {
+  if (input_data_ != nullptr) {
+    context_->allocator->Free(input_data_);
+    input_data_ = nullptr;
+  }
+}
+
+int SliceFp16CPUKernel::Init() {
+  auto input_tensor = in_tensors_.at(0);
+  if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data_c() != nullptr) {
+    input_data_ =
+      reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
+    Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
+  }
+  return SliceCPUKernel::Init();
+}
+
+int SliceFp16CPUKernel::SliceFp16ParallelRun(int thread_id) {
+  void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
+  DoSlice(input_data, out_tensors_.at(0)->data_c(), param_, thread_id, lite::DataTypeSize(kNumberTypeFloat16));
+  return RET_OK;
+}
+
+int SliceFp16CPUKernel::Run() {
+  void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
+  if (param_->size_[1] < op_parameter_->thread_num_) {
+    DoSliceNoParallel(input_data, out_tensors_.at(0)->data_c(), param_, lite::DataTypeSize(kNumberTypeFloat16));
+    return RET_OK;
+  }
+  auto ret = ParallelLaunch(this->context_->thread_pool_, SliceFp16Launch, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "fp16 slice launch fail!ret: " << ret;
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceFp16CPUKernel>)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.h
new file mode 100644
index 0000000000..cb99cddfa3
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
+
+#include <vector>
+#include "src/lite_kernel.h"
+#include "src/runtime/kernel/arm/base/slice_base.h"
+
+namespace mindspore::kernel {
+class SliceFp16CPUKernel : public SliceCPUKernel {
+ public:
+  SliceFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : SliceCPUKernel(parameter, inputs, outputs, ctx) {}
+  ~SliceFp16CPUKernel() override;
+
+  int Init() override;
+  int Run() override;
+  int SliceFp16ParallelRun(int thread_id);
+
+ private:
+  float16_t *input_data_ = nullptr;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 02e7cfe642..6b40ca78cf 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -243,6 +243,9 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
         ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 62d9536892..1dc267e005 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -98,7 +98,7 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
       MS_LOG(ERROR) << "value node is invalid.";
       return;
     }
-    if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTuple) ||
+    if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
                                            opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
       has_make_tuple = true;
       for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
@@ -372,7 +372,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr
-    if (opt::CheckPrimitiveType(cnode, opt::kPrimReturn)) {
+    if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
       node->name = mindspore::ops::kNameReturn;
       ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get());
       if (ret != RET_OK) {
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
index 575b845cbd..e181a48d4b 100644
--- a/mindspore/lite/tools/converter/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/CMakeLists.txt
@@ -53,6 +53,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/fusion/tf_bidirection_gru_fusion.cc
         ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
         ../optimizer/fusion/matmul_add_fusion.cc
+        ../optimizer/fusion/gelu_fusion.cc
+        ../optimizer/fusion/tf_gelu_fusion.cc
+        ../optimizer/fusion/onnx_gelu_fusion.cc
         ../optimizer/graph/weight_format_transform_pass.cc
         ../optimizer/graph/weight_format_hardcode_pass.cc
         ../optimizer/graph/clip_convert_activation_pass.cc
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index c9bd5547d8..24d32a424e 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -37,6 +37,8 @@
 #include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
 #include "tools/optimizer/fusion/matmul_add_fusion.h"
 #include "tools/optimizer/graph/primitive_adjust_pass.h"
+#include "tools/optimizer/fusion/tf_gelu_fusion.h"
+#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
 #include "tools/optimizer/graph/mindir_adjust_pass.h"
 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
 #include "tools/optimizer/graph/weight_format_hardcode_pass.h"
@@ -89,6 +91,8 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
     fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
     fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruCfFusion>());
     fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
   }
   if (config->fmk == lite::converter::FmkType_MS) {
     auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index dc3edb7e66..d30b0c36d3 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -54,10 +54,10 @@ bool IsRealKernel(const AnfNodePtr &node) {
   auto input = cnode->inputs()[0];
   bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                          IsPrimitive(input, prim::kPrimTensorSummary) ||
-                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, kPrimMakeTuple) ||
+                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                          IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                          IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
-                         IsPrimitive(input, kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
+                         IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
   return !is_virtual_node;
 }
@@ -335,7 +335,7 @@ bool IsRealCNodeKernel(const AnfNodePtr &node) {
     return false;
   }
   // return considered as a real node
-  if (CheckPrimitiveType(node, kPrimReturn)) {
+  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
     return true;
   }
   return IsRealKernel(node);
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h
index 1db8048ec7..880417b44d 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.h
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -35,8 +35,8 @@ using mindspore::lite::RET_OK;
 using mindspore::lite::STATUS;
 namespace mindspore {
 namespace opt {
-inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
-inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("MakeTuple");
+inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
+inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
 inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
 inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
 std::vector<int> CastToInt(const ValuePtr &value);
@@ -145,6 +145,15 @@ ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const st
 ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
                                           const std::string &node_name);
+
+template <const PrimitivePtr *prim>
+inline bool IsSpecifiedNode(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    auto anf_node = utils::cast<AnfNodePtr>(n);
+    return CheckPrimitiveType(anf_node, *prim);
+  }
+  return false;
+}
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_
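The new `IsSpecifiedNode` template replaces the hand-written per-op predicates (`IsAddNode`, `IsMulNode`, ...) that norm_fusion.cc deletes further down: the primitive becomes a non-type template parameter, so each instantiation is a distinct plain function that can be handed to a `CondVar`. A self-contained sketch of the idea, with simplified stand-in types for MindSpore's `BaseRef`/`AnfNodePtr` machinery (the struct and comparison below are illustrative assumptions, not the real API):

```cpp
#include <memory>
#include <string>

// Stand-in for mindspore::Primitive; only the name matters for the sketch.
struct Primitive {
  explicit Primitive(std::string n) : name(std::move(n)) {}
  std::string name;
};
using PrimitivePtr = std::shared_ptr<Primitive>;

// inline variable (C++17) has external linkage, so &kPrimAddFusion is a
// valid non-type template argument.
inline const PrimitivePtr kPrimAddFusion = std::make_shared<Primitive>("AddFusion");

// One template instead of N hand-written IsXxxNode helpers; the string
// parameter stands in for the BaseRef node being tested.
template <const PrimitivePtr *prim>
bool IsSpecifiedNode(const std::string &node_op) {
  return node_op == (*prim)->name;
}

// Each instantiation is its own predicate, usable wherever a function
// pointer (e.g. a CondVar condition) is expected.
bool demo = IsSpecifiedNode<&kPrimAddFusion>("AddFusion");  // true
```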
diff --git a/mindspore/lite/tools/optimizer/fusion/gelu_fusion.cc b/mindspore/lite/tools/optimizer/fusion/gelu_fusion.cc
new file mode 100644
index 0000000000..3545f41dea
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/gelu_fusion.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/fusion/gelu_fusion.h"
+#include <memory>
+#include <string>
+#include "ops/fusion/activation.h"
+#include "utils/utils.h"
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore {
+namespace opt {
+CNodePtr GeLUFusion::CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                    const EquivPtr &equiv) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  auto gelu_prim = std::make_shared<ops::Activation>();
+  gelu_prim->set_activation_type(mindspore::GELU);
+  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
+  MS_ASSERT(input_node != nullptr);
+  auto gelu_cnode = func_graph->NewCNode(gelu_prim, {input_node});
+  gelu_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_gelu");
+  gelu_cnode->set_abstract(node->abstract()->Clone());
+  return gelu_cnode;
+}
+
+const float GeLUFusion::GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const {
+  MS_ASSERT(equiv != nullptr);
+  MS_ASSERT(input != nullptr);
+  float value = -1;
+  auto node = utils::cast<AnfNodePtr>((*equiv)[input]);
+  if (node == nullptr || !utils::isa<ParameterPtr>(node)) {
+    return value;
+  }
+  auto parameter_node = node->cast<ParameterPtr>();
+  if (!parameter_node->has_default() || parameter_node->default_param() == nullptr) {
+    return value;
+  }
+  auto param_value_lite = parameter_node->default_param()->cast<ParamValueLitePtr>();
+  if (param_value_lite == nullptr) {
+    return value;
+  }
+  if (param_value_lite->tensor_type() != kNumberTypeFloat32 && param_value_lite->tensor_type() != kNumberTypeFloat) {
+    return value;
+  }
+  if (param_value_lite->tensor_size() != sizeof(float)) {
+    return value;
+  }
+  return *static_cast<float *>(param_value_lite->tensor_addr());
+}
+
+const AnfNodePtr GeLUFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                     const EquivPtr &equiv) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  MS_ASSERT(equiv != nullptr);
+  MS_LOG(DEBUG) << "gelu_fusion pass";
+  if (!utils::isa<CNodePtr>(node)) {
+    return nullptr;
+  }
+  if (!CheckPattern(equiv)) {
+    return nullptr;
+  }
+  auto cnode = CreateGeLUNode(func_graph, node, equiv);
+  if (cnode == nullptr) {
+    MS_LOG(DEBUG) << "new gelu node failed.";
+    return nullptr;
+  }
+  return cnode;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/fusion/gelu_fusion.h b/mindspore/lite/tools/optimizer/fusion/gelu_fusion.h
new file mode 100644
index 0000000000..f020155273
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/gelu_fusion.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
+
+#include <string>
+#include <memory>
+#include <vector>
+#include "backend/optimizer/common/optimizer.h"
+#include "utils/utils.h"
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore {
+namespace opt {
+class GeLUFusion : public PatternProcessPass {
+ public:
+  explicit GeLUFusion(const std::string &name = "gelu_fusion", bool multigraph = true)
+      : PatternProcessPass(name, multigraph), input_(std::make_shared<Var>()) {}
+
+  ~GeLUFusion() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ protected:
+  virtual bool CheckPattern(const EquivPtr &equiv) const = 0;
+  const float GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const;
+  VarPtr input_ = nullptr;
+
+ private:
+  CNodePtr CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
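The two concrete subclasses added later in this patch match the two standard GeLU decompositions, and the constants their `CheckPattern` overrides verify are just the literals in these closed forms (constants taken from the fusion sources below):

```latex
% Exact form, matched by OnnxGeLUFusion:
%   DIV_Y = sqrt(2) ~ 1.41421, ADD_Y = 1.0, MUL1_y = 0.5
\[
  \operatorname{GELU}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
\]
% Tanh approximation, matched by TfGeLUFusion:
%   MUL2_X = sqrt(2/pi) ~ 0.79788, MUL1_Y = 0.044715, POW_Y = 3,
%   ADD2_X = 1.0, MUL3_X = 0.5
\[
  \operatorname{GELU}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)
\]
```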
diff --git a/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
index ec068e21a5..a96d8dff16 100644
--- a/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
@@ -20,19 +20,19 @@ namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t AddInputSize = 3;
-constexpr size_t MatMulInputSize = 3;
+constexpr size_t kAddInputSize = 3;
+constexpr size_t kMatMulInputSize = 3;
 bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
   MS_ASSERT(cnode != nullptr);
   MS_ASSERT(index != nullptr);
-  if (cnode->size() != AddInputSize) {
+  if (cnode->size() != kAddInputSize) {
     return false;
   }
   size_t matmul_index = 0;
   for (size_t i = 1; i < cnode->size(); ++i) {
     if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) {
       auto matmul_cnode = cnode->input(i)->cast<CNodePtr>();
-      if (matmul_cnode->size() > MatMulInputSize) {
+      if (matmul_cnode->size() > kMatMulInputSize) {
         continue;
       }
       matmul_index = i;
@@ -63,7 +63,7 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
       continue;
     }
     auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
-    auto bias_node = cnode->input(AddInputSize - index);
+    auto bias_node = cnode->input(kAddInputSize - index);
     if (!utils::isa<ParameterPtr>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
       continue;
     }
diff --git a/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc b/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc
index 0676f86d45..6d5d0ff677 100644
--- a/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc
@@ -17,7 +17,6 @@
 #include <memory>
 #include "ops/fusion/layer_norm_fusion.h"
 #include "ops/fusion/reduce_fusion.h"
-#include "ops/rsqrt.h"
 #include "mindspore/core/ops/instance_norm.h"
 #include "src/param_value_lite.h"
 #include "utils/utils.h"
@@ -27,60 +26,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-inline bool IsAddNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
-  }
-  return false;
-}
-
-inline bool IsSquaredDifferenceNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference);
-  }
-  return false;
-}
-
-inline bool IsRsqrtNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt);
-  }
-  return false;
-}
-
-inline bool IsMulNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion);
-  }
-  return false;
-}
-
-inline bool IsSubNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion);
-  }
-  return false;
-}
-inline bool IsPowNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion);
-  }
-  return false;
-}
-inline bool IsSqrtNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt);
-  }
-  return false;
-}
-inline bool IsDivNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) ||
-           CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion"));
-  }
-  return false;
-}
-
 STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
   MS_ASSERT(node != nullptr);
   if (utils::isa<ParameterPtr>(n)) {
@@ -195,7 +140,7 @@ bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vect
     }
   }
   if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast<int>(shape.size())) {
-    MS_LOG(DEBUG) << "mean node is not reduce to last axis";
+    MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
     return false;
   }
@@ -318,37 +263,41 @@ const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNo
 
 const BaseRef TfNormFusion::DefinePattern() const {
   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
-  auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
+  auto squared_diffference1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquaredDifference>);
   VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
-  auto mul1 = std::make_shared<CondVar>(IsMulNode);
+  auto mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
   VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
-  auto add1 = std::make_shared<CondVar>(IsAddNode);
+  auto add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
   VectorRef add1_ref = VectorRef({add1, mean2_ref, epsilon_});
-  auto rsqrt1 = std::make_shared<CondVar>(IsRsqrtNode);
+  auto rsqrt1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRsqrt>);
   VectorRef rsqrt1_ref = VectorRef({rsqrt1, add1_ref});
-  auto mul2 = std::make_shared<CondVar>(IsMulNode);
+  auto mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
   VectorRef mul2_ref = VectorRef({mul2, rsqrt1_ref, gamma_});
   VectorRef mul1_ref = VectorRef({mul1, input_, mul2_ref});
-  auto mul3 = std::make_shared<CondVar>(IsMulNode);
+  auto mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
   VectorRef mul3_ref = VectorRef({mul3, mean1_ref, mul2_ref});
-  auto sub1 = std::make_shared<CondVar>(IsSubNode);
+  auto sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
   VectorRef sub1_ref = VectorRef({sub1, beta_, mul3_ref});
-  auto add2 = std::make_shared<CondVar>(IsAddNode);
+  auto add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
   VectorRef add2_ref = VectorRef({add2, mul1_ref, sub1_ref});
   return add2_ref;
 }
 
 const BaseRef OnnxLayerNormFusion::DefinePattern() const {
   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
-  VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
-  VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
-  VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()});
+  VectorRef sub1_ref =
+    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>), input_, mean1_ref});
+  VectorRef sub2_ref =
+    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>), input_, mean1_ref});
+  VectorRef pow_ref =
+    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>), sub2_ref, std::make_shared<Var>()});
   VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
-  VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_});
-  VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref});
-  VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref});
-  VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref});
-  VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_});
+  VectorRef add1_ref =
+    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mean2_ref, epsilon_});
+  VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>), add1_ref});
+  VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), sub1_ref, sqrt_ref});
+  VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), gamma_, div_ref});
+  VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mul_ref, beta_});
   return add2_ref;
 }
 }  // namespace opt
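The TfNormFusion pattern above matches the op-by-op expansion of layer norm that TF graphs emit (mean, SquaredDifference, mean, Add(eps), Rsqrt, then scale/shift). A scalar sketch of that expansion, to make the dataflow concrete; gamma and beta are scalars here for brevity, whereas the real fused op carries per-element vectors, and this is not the nnacl kernel:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar walk-through of the sub-graph TfNormFusion collapses into one
// LayerNormFusion node, over a single (last-axis) slice x.
std::vector<float> LayerNormExpanded(const std::vector<float> &x, float gamma, float beta, float eps) {
  float mean1 = 0.f;                                         // mean1_ref
  for (float v : x) mean1 += v;
  mean1 /= x.size();
  float mean2 = 0.f;                                         // squared_diffference1_ref + mean2_ref
  for (float v : x) mean2 += (v - mean1) * (v - mean1);
  mean2 /= x.size();
  float rsqrt1 = 1.f / std::sqrt(mean2 + eps);               // add1_ref + rsqrt1_ref
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    // mul1/mul2/mul3/sub1/add2 algebraically fold to this scale-and-shift:
    // x*rsqrt*gamma + (beta - mean*rsqrt*gamma) == (x - mean)*rsqrt*gamma + beta
    out[i] = (x[i] - mean1) * rsqrt1 * gamma + beta;
  }
  return out;
}
```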
diff --git a/mindspore/lite/tools/optimizer/fusion/norm_fusion.h b/mindspore/lite/tools/optimizer/fusion/norm_fusion.h
index 44e9410e0c..683cc5fa2f 100644
--- a/mindspore/lite/tools/optimizer/fusion/norm_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/norm_fusion.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
-#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
 
 #include <memory>
 #include <string>
@@ -31,7 +31,7 @@ namespace opt {
 /// fuse layer_norm or instance_norm into one operator
 class NormFusion : public PatternProcessPass {
  public:
-  explicit NormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
+  explicit NormFusion(const std::string &name = "norm_fusion", bool multigraph = true)
       : PatternProcessPass(name, multigraph) {
     input_ = std::make_shared<Var>();
     mean1_ = std::make_shared<Var>();
@@ -44,7 +44,6 @@ class NormFusion : public PatternProcessPass {
   }
 
   ~NormFusion() override = default;
-  virtual const BaseRef DefinePattern() const = 0;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
@@ -70,6 +69,9 @@ class NormFusion : public PatternProcessPass {
 /// fuse tf layer_norm or instance_norm into one operator
 class TfNormFusion : public NormFusion {
  public:
+  explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
+      : NormFusion(name, multigraph) {}
+  ~TfNormFusion() override = default;
   const BaseRef DefinePattern() const override;
 };
 
@@ -77,11 +79,13 @@ class TfNormFusion : public NormFusion {
 /// fuse onnx layer_norm into one operator
 class OnnxLayerNormFusion : public NormFusion {
  public:
+  explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true)
+      : NormFusion(name, multigraph) {}
+  ~OnnxLayerNormFusion() override = default;
   const BaseRef DefinePattern() const override;
 };
-
 }  // namespace opt
 }  // namespace mindspore
 
-#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/onnx_gelu_fusion.cc b/mindspore/lite/tools/optimizer/fusion/onnx_gelu_fusion.cc
new file mode 100644
index 0000000000..9fd478d40f
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/onnx_gelu_fusion.cc
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr float DIFF_THRESHOLD = 0.0001;
+constexpr float DIV_Y = 1.41421;
+constexpr float ADD_Y = 1.0;
+constexpr float MUL1_y = 0.5;
+}  // namespace
+
+// gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))]
+const BaseRef OnnxGeLUFusion::DefinePattern() const {
+  VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_});
+  VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref});
+  VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_});
+  VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_});
+  VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref});
+  return mul2_ref;
+}
+
+bool OnnxGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
+  MS_ASSERT(equiv != nullptr);
+  float div_y = GetParameterValue(equiv, div_y_);
+  if (div_y < 0 || fabs(div_y - DIV_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float add_y = GetParameterValue(equiv, add_y_);
+  if (add_y < 0 || fabs(add_y - ADD_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float mul1_y = GetParameterValue(equiv, mul1_y_);
+  if (mul1_y < 0 || fabs(mul1_y - MUL1_y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/fusion/onnx_gelu_fusion.h b/mindspore/lite/tools/optimizer/fusion/onnx_gelu_fusion.h
new file mode 100644
index 0000000000..73e94eb21a
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/onnx_gelu_fusion.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
+
+#include <vector>
+#include <memory>
+#include <string>
+#include "tools/optimizer/fusion/gelu_fusion.h"
+
+namespace mindspore {
+namespace opt {
+class OnnxGeLUFusion : public GeLUFusion {
+ public:
+  explicit OnnxGeLUFusion(const std::string &name = "onnx_gelu_fusion", bool multigraph = true)
+      : GeLUFusion(name, multigraph) {
+    div_y_ = std::make_shared<Var>();
+    add_y_ = std::make_shared<Var>();
+    mul1_y_ = std::make_shared<Var>();
+  }
+  ~OnnxGeLUFusion() override = default;
+
+ private:
+  bool CheckPattern(const EquivPtr &equiv) const override;
+  const BaseRef DefinePattern() const override;
+
+ private:
+  VarPtr div_y_ = nullptr;
+  VarPtr add_y_ = nullptr;
+  VarPtr mul1_y_ = nullptr;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
index a04ffe7cdc..420615f0a2 100644
--- a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
@@ -133,7 +133,7 @@ AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr
   auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
   auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
   auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
   VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
   VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
@@ -183,13 +183,13 @@ AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr
   VectorRef select_hidden = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[4], new_hidden});
 
-  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
   std::vector<BaseRef> outputs = {is_make_tuple, add1, placeholders[1], add, output, select_hidden,
                                   placeholders[5], placeholders[6], placeholders[7]};
   outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end());
   VectorRef make_tuple_node = VectorRef(outputs);
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef return_node = VectorRef({is_return, make_tuple_node});
 
   VarPtr fg = std::make_shared<Var>("RootG");
diff --git a/mindspore/lite/tools/optimizer/fusion/tf_gelu_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_gelu_fusion.cc
new file mode 100644
index 0000000000..3b59b98595
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/tf_gelu_fusion.cc
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/fusion/tf_gelu_fusion.h"
+#include "ops/op_utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr float DIFF_THRESHOLD = 0.0001;
+constexpr float POW_Y = 3;
+constexpr float MUL1_Y = 0.044715;
+constexpr float MUL2_X = 0.79788;
+constexpr float ADD2_X = 1.0;
+constexpr float MUL3_X = 0.5;
+bool CheckTanh(const EquivPtr &equiv, const VarPtr &input) {
+  MS_ASSERT(equiv != nullptr);
+  MS_ASSERT(input != nullptr);
+  auto anf_node = utils::cast<AnfNodePtr>((*equiv)[input]);
+  MS_ASSERT(anf_node != nullptr);
+  AnfNodePtr value_node = anf_node;
+  if (anf_node->isa<CNode>()) {
+    value_node = anf_node->cast<CNodePtr>()->input(0);
+  }
+  auto act_prim = GetValueNode<PrimitivePtr>(value_node);
+  if (act_prim == nullptr) {
+    return false;
+  }
+  return act_prim->GetAttr(ops::kActivationType) != nullptr &&
+         GetValue<int64_t>(act_prim->GetAttr(ops::kActivationType)) == mindspore::TANH;
+}
+}  // namespace
+
+// gelu(x) = 1/2 * x * [1 + tanh(0.79788 * (x + 0.044715 * x ^ 3))]
+const BaseRef TfGeLUFusion::DefinePattern() const {
+  VectorRef pow_ref({power_, input_, power_y_});
+  VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_x_, pow_ref});
+  VectorRef add1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), input_, mul1_ref});
+  VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul2_x_, add1_ref});
+  VectorRef tanh_ref({tanh_, mul2_ref});
+  VectorRef add2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), add2_x_, tanh_ref});
+  VectorRef mul3_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul3_x_, add2_ref});
+  VectorRef mul4_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul3_ref});
+  return mul4_ref;
+}
+
+bool TfGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
+  MS_ASSERT(equiv != nullptr);
+  if (!CheckTanh(equiv, tanh_)) {
+    return false;
+  }
+  float pow_y = GetParameterValue(equiv, power_y_);
+  if (pow_y < 0 || fabs(pow_y - POW_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float mul1_y = GetParameterValue(equiv, mul1_x_);
+  if (mul1_y < 0 || fabs(mul1_y - MUL1_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float mul2_x = GetParameterValue(equiv, mul2_x_);
+  if (mul2_x < 0 || fabs(mul2_x - MUL2_X) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float add2_x = GetParameterValue(equiv, add2_x_);
+  if (add2_x < 0 || fabs(add2_x - ADD2_X) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float mul3_x = GetParameterValue(equiv, mul3_x_);
+  if (mul3_x < 0 || fabs(mul3_x - MUL3_X) > DIFF_THRESHOLD) {
+    return false;
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
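A standalone numeric sanity check of the hard-coded literals above: within the pass's own DIFF_THRESHOLD, DIV_Y matches sqrt(2), MUL2_X matches sqrt(2/pi), and the tanh form tracks the erf form closely (not part of the patch; `M_PI` assumes a POSIX-style `<cmath>`):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double kDiffThreshold = 0.0001;
  // Constants checked by OnnxGeLUFusion / TfGeLUFusion::CheckPattern.
  std::printf("sqrt(2)    = %.6f vs DIV_Y  1.41421 -> %s\n", std::sqrt(2.0),
              std::fabs(std::sqrt(2.0) - 1.41421) < kDiffThreshold ? "ok" : "off");
  std::printf("sqrt(2/pi) = %.6f vs MUL2_X 0.79788 -> %s\n", std::sqrt(2.0 / M_PI),
              std::fabs(std::sqrt(2.0 / M_PI) - 0.79788) < kDiffThreshold ? "ok" : "off");
  // Compare the two GeLU forms at a sample point.
  double x = 1.0;
  double exact = 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
  double approx = 0.5 * x * (1.0 + std::tanh(std::sqrt(2.0 / M_PI) * (x + 0.044715 * x * x * x)));
  std::printf("gelu(1): exact %.6f, tanh approx %.6f\n", exact, approx);
  return 0;
}
```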
diff --git a/mindspore/lite/tools/optimizer/fusion/tf_gelu_fusion.h b/mindspore/lite/tools/optimizer/fusion/tf_gelu_fusion.h
new file mode 100644
index 0000000000..b6ba5eb0a7
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/tf_gelu_fusion.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
+
+#include <vector>
+#include <memory>
+#include <string>
+#include "tools/optimizer/fusion/gelu_fusion.h"
+
+namespace mindspore {
+namespace opt {
+class TfGeLUFusion : public GeLUFusion {
+ public:
+  explicit TfGeLUFusion(const std::string &name = "tf_gelu_fusion", bool multigraph = true)
+      : GeLUFusion(name, multigraph) {
+    power_ = std::make_shared<Var>();
+    power_y_ = std::make_shared<Var>();
+    mul1_x_ = std::make_shared<Var>();
+    mul2_x_ = std::make_shared<Var>();
+    tanh_ = std::make_shared<Var>();
+    add2_x_ = std::make_shared<Var>();
+    mul3_x_ = std::make_shared<Var>();
+  }
+  ~TfGeLUFusion() override = default;
+
+ private:
+  bool CheckPattern(const EquivPtr &equiv) const override;
+  const BaseRef DefinePattern() const override;
+
+ private:
+  VarPtr power_ = nullptr;
+  VarPtr power_y_ = nullptr;
+  VarPtr mul1_x_ = nullptr;
+  VarPtr mul2_x_ = nullptr;
+  VarPtr tanh_ = nullptr;
+  VarPtr add2_x_ = nullptr;
+  VarPtr mul3_x_ = nullptr;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
index 59203b52aa..775da6c4d9 100644
--- a/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
@@ -98,11 +98,11 @@ AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primi
   VectorRef set_item = VectorRef({std::make_shared<Var>(""), placeholders[3], placeholders[2], new_hidden});
 
-  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
   std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden};
   outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
   VectorRef make_tuple_node = VectorRef(outputs);
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef return_node = VectorRef({is_return, make_tuple_node});
 
   VarPtr fg = std::make_shared<Var>("RootG");
diff --git a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
index 0b13d4bcb3..d7d65a1b12 100644
--- a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
@@ -116,7 +116,7 @@ AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &p
   auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
   auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
   auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1,
prim::kPrimReturn));
   VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
   VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
   VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
@@ -174,11 +174,11 @@ AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &p
   VectorRef set_item = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], output});
 
-  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
   std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
   outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
   VectorRef make_tuple_node = VectorRef(outputs);
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef return_node = VectorRef({is_return, make_tuple_node});
 
   VarPtr fg = std::make_shared<Var>("RootG");
diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc
index 94471e9545..fd4fa7e7c3 100644
--- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc
@@ -41,8 +41,8 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) {
 
 bool IsSpecialType(const CNodePtr &cnode) {
   if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
-      CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, kPrimMakeTuple) ||
-      CheckPrimitiveType(cnode, kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
+      CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) ||
+      CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
       CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) {
     return true;
   }
diff --git a/mindspore/lite/tools/optimizer/graph/while_pass.cc b/mindspore/lite/tools/optimizer/graph/while_pass.cc
index 93ac70a3bb..b492726ebf 100644
--- a/mindspore/lite/tools/optimizer/graph/while_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/while_pass.cc
@@ -81,7 +81,7 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
 
   // concat body to cond
   std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode};
-  if (CheckPrimitiveType(body_output_cnode, kPrimMakeTuple)) {
+  if (CheckPrimitiveType(body_output_cnode, prim::kPrimMakeTuple)) {
    for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) {
      body_to_cond_inputs.emplace_back(body_output_cnode->input(i));
    }
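For orientation, how the new fusions participate in conversion, sketched with only the calls visible in the anf_transform.cc hunk above (the wrapper function and the pass_manager.h include path are illustrative assumptions, not part of the patch). Per pass, the pattern engine uses DefinePattern() to find candidate sub-graphs, Process() runs CheckPattern() to verify the matched constants via GetParameterValue(), and CreateGeLUNode() replaces the whole sub-graph with a single Activation(GELU) node:

```cpp
#include <memory>
#include "backend/optimizer/common/pass_manager.h"  // assumed header for opt::PassManager
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"

// Hypothetical wrapper mirroring the two AddPass calls the patch adds to
// AnfTransform::AddFusionPass; registration order after the norm fusions
// follows the hunk above.
void AddGeLUFusions(const std::shared_ptr<mindspore::opt::PassManager> &fusion_pm) {
  fusion_pm->AddPass(std::make_shared<mindspore::opt::TfGeLUFusion>());
  fusion_pm->AddPass(std::make_shared<mindspore::opt::OnnxGeLUFusion>());
}
```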