From: @xu_anyue
Reviewed-by: @hangangqiang, @HilbertDavid
Signed-off-by: @hangangqiang
pull/13996/MERGE
@@ -76,18 +76,12 @@ int SliceCPUKernel::SliceParallelRun(int thread_id) {
}

int SliceCPUKernel::Run() {
  auto ret = PreProcess();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "PreProcess fail! ret: " << ret;
    return ret;
  }
  if (param_->size_[1] < op_parameter_->thread_num_) {
    DoSliceNoParallel(in_tensors_.at(0)->data_c(), out_tensors_.at(0)->data_c(), param_,
                      lite::DataTypeSize(in_tensors_.at(0)->data_type()));
    return RET_OK;
  }
  ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
  auto ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "slice launch fail! ret: " << ret;
    return RET_ERROR;
@@ -96,6 +90,5 @@ int SliceCPUKernel::Run() {
}

REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
}  // namespace mindspore::kernel
@@ -0,0 +1,75 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp16/slice_fp16.h"
#include "src/kernel_registry.h"
#include "nnacl/base/slice_base.h"
#include "nnacl/fp16/cast_fp16.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SliceFusion;

namespace mindspore::kernel {
int SliceFp16Launch(void *cdata, int task_id) {
  if (cdata == nullptr) {
    MS_LOG(ERROR) << "Input cdata is nullptr!";
    return RET_ERROR;
  }
  auto kernel = reinterpret_cast<SliceFp16CPUKernel *>(cdata);
  return kernel->SliceFp16ParallelRun(task_id);
}

SliceFp16CPUKernel::~SliceFp16CPUKernel() {
  if (input_data_ != nullptr) {
    context_->allocator->Free(input_data_);
    input_data_ = nullptr;
  }
}

int SliceFp16CPUKernel::Init() {
  auto input_tensor = in_tensors_.at(0);
  if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data_c() != nullptr) {
    input_data_ =
      reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
    Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
  }
  return SliceCPUKernel::Init();
}

int SliceFp16CPUKernel::SliceFp16ParallelRun(int thread_id) {
  void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
  DoSlice(input_data, out_tensors_.at(0)->data_c(), param_, thread_id, lite::DataTypeSize(kNumberTypeFloat16));
  return RET_OK;
}

int SliceFp16CPUKernel::Run() {
  void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
  if (param_->size_[1] < op_parameter_->thread_num_) {
    DoSliceNoParallel(input_data, out_tensors_.at(0)->data_c(), param_, lite::DataTypeSize(kNumberTypeFloat16));
    return RET_OK;
  }
  auto ret = ParallelLaunch(this->context_->thread_pool_, SliceFp16Launch, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
| MS_LOG(ERROR) << "fp16 slice launch fail!ret: " << ret; | |||
    return RET_ERROR;
  }
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceFp16CPUKernel>)
}  // namespace mindspore::kernel
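
Note on the Init()-time conversion above: when the slice input is a constant fp32 tensor, Init() casts it to fp16 once and caches the buffer in input_data_, so Run() and SliceFp16ParallelRun() fall back to the tensor's own data only when no cached copy exists. A minimal, self-contained sketch of that one-time cast (assuming an ARM toolchain where float16_t comes from <arm_neon.h>, as in nnacl's fp16 code; Float32ToFloat16Sketch is a hypothetical stand-in for nnacl's Float32ToFloat16):

#include <arm_neon.h>
#include <cstddef>

// One-time element-wise narrowing cast, done at Init() so every later
// Run() can slice the cached fp16 buffer directly.
void Float32ToFloat16Sketch(const float *in, float16_t *out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = static_cast<float16_t>(in[i]);
  }
}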
@@ -0,0 +1,41 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/slice_base.h"

namespace mindspore::kernel {
class SliceFp16CPUKernel : public SliceCPUKernel {
 public:
  SliceFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : SliceCPUKernel(parameter, inputs, outputs, ctx) {}
  ~SliceFp16CPUKernel() override;

  int Init() override;
  int Run() override;
  int SliceFp16ParallelRun(int thread_id);

 private:
  float16_t *input_data_ = nullptr;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
@@ -243,6 +243,9 @@ if(ENABLE_CONVERTER)
    ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
    ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
    ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
    ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
    ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
    ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
    ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
    ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
    ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
@@ -98,7 +98,7 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
    MS_LOG(ERROR) << "value node is invalid.";
    return;
  }
  if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTuple) ||
  if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
                                         opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
    has_make_tuple = true;
    for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
@@ -372,7 +372,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
      ret = RET_MEMORY_FAILED;
      break;
    }
    if (opt::CheckPrimitiveType(cnode, opt::kPrimReturn)) {
    if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
      node->name = mindspore::ops::kNameReturn;
      ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get());
      if (ret != RET_OK) {
@@ -53,6 +53,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
    ../optimizer/fusion/tf_bidirection_gru_fusion.cc
    ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
    ../optimizer/fusion/matmul_add_fusion.cc
    ../optimizer/fusion/gelu_fusion.cc
    ../optimizer/fusion/tf_gelu_fusion.cc
    ../optimizer/fusion/onnx_gelu_fusion.cc
    ../optimizer/graph/weight_format_transform_pass.cc
    ../optimizer/graph/weight_format_hardcode_pass.cc
    ../optimizer/graph/clip_convert_activation_pass.cc
@@ -37,6 +37,8 @@
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/graph/primitive_adjust_pass.h"
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
@@ -89,6 +91,8 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
    fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
    fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
    fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
    fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
    fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
  }
  if (config->fmk == lite::converter::FmkType_MS) {
    auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
@@ -54,10 +54,10 @@ bool IsRealKernel(const AnfNodePtr &node) {
  auto input = cnode->inputs()[0];
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                         IsPrimitive(input, prim::kPrimTensorSummary) ||
                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, kPrimMakeTuple) ||
                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
                         IsPrimitive(input, kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
                         IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
  return !is_virtual_node;
}
@@ -335,7 +335,7 @@ bool IsRealCNodeKernel(const AnfNodePtr &node) {
    return false;
  }
  // return considered as a real node
  if (CheckPrimitiveType(node, kPrimReturn)) {
  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
    return true;
  }
  return IsRealKernel(node);
@@ -35,8 +35,8 @@ using mindspore::lite::RET_OK;
using mindspore::lite::STATUS;
namespace mindspore {
namespace opt {
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("MakeTuple");
inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
std::vector<int> CastToInt(const ValuePtr &value);
@@ -145,6 +145,15 @@ ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const st
ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
                                          const std::string &node_name);

template <const PrimitivePtr *prim = nullptr>
inline bool IsSpecifiedNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    auto anf_node = utils::cast<AnfNodePtr>(n);
    return CheckPrimitiveType(anf_node, *prim);
  }
  return false;
}
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_
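
For readers new to this trick: IsSpecifiedNode takes the primitive as a non-type template parameter (a pointer to the PrimitivePtr object), so a single template generates every per-op predicate that used to be hand-written (see the IsAddNode/IsMulNode helpers deleted from norm_fusion.cc further down). A self-contained model of the mechanism, using simplified stand-in types rather than the real MindSpore classes:

#include <iostream>
#include <memory>
#include <string>

struct Primitive { std::string name; };  // stand-in for mindspore::Primitive
using PrimitivePtr = std::shared_ptr<Primitive>;

inline const PrimitivePtr kPrimAddFusion = std::make_shared<Primitive>(Primitive{"AddFusion"});
inline const PrimitivePtr kPrimMulFusion = std::make_shared<Primitive>(Primitive{"MulFusion"});

// Stand-in for CheckPrimitiveType(anf_node, prim).
bool CheckPrimitiveType(const std::string &node_type, const PrimitivePtr &prim) {
  return node_type == prim->name;
}

// One template replaces one hand-written IsXxxNode predicate per op.
template <const PrimitivePtr *prim>
bool IsSpecifiedNode(const std::string &node_type) {
  return CheckPrimitiveType(node_type, *prim);
}

int main() {
  std::cout << IsSpecifiedNode<&kPrimAddFusion>("AddFusion")   // prints 1
            << IsSpecifiedNode<&kPrimMulFusion>("AddFusion");  // prints 0
}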
@@ -0,0 +1,85 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/gelu_fusion.h"
#include <memory>
#include <string>
#include "ops/fusion/activation.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace opt {
CNodePtr GeLUFusion::CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                    const EquivPtr &equiv) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto gelu_prim = std::make_shared<ops::Activation>();
  gelu_prim->set_activation_type(mindspore::GELU);
  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
  MS_ASSERT(input_node != nullptr);
  auto gelu_cnode = func_graph->NewCNode(gelu_prim, {input_node});
  gelu_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_gelu");
  gelu_cnode->set_abstract(node->abstract()->Clone());
  return gelu_cnode;
}

const float GeLUFusion::GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const {
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(input != nullptr);
  float value = -1;
  auto node = utils::cast<AnfNodePtr>((*equiv)[input]);
  if (node == nullptr || !utils::isa<ParameterPtr>(node)) {
    return value;
  }
  auto parameter_node = node->cast<ParameterPtr>();
  if (!parameter_node->has_default() || parameter_node->default_param() == nullptr) {
    return value;
  }
  auto param_value_lite = parameter_node->default_param()->cast<ParamValueLitePtr>();
  if (param_value_lite == nullptr) {
    return value;
  }
  if (param_value_lite->tensor_type() != kNumberTypeFloat32 && param_value_lite->tensor_type() != kNumberTypeFloat) {
    return value;
  }
  if (param_value_lite->tensor_size() != sizeof(float)) {
    return value;
  }
  return *static_cast<float *>(param_value_lite->tensor_addr());
}

const AnfNodePtr GeLUFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                     const EquivPtr &equiv) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_LOG(DEBUG) << "gelu_fusion pass";
  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  if (!CheckPattern(equiv)) {
    return nullptr;
  }
  auto cnode = CreateGeLUNode(func_graph, node, equiv);
  if (cnode == nullptr) {
    MS_LOG(DEBUG) << "new gelu node failed.";
    return nullptr;
  }
  return cnode;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,48 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_

#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace opt {
class GeLUFusion : public PatternProcessPass {
 public:
  explicit GeLUFusion(const std::string &name = "gelu_fusion", bool multigraph = true)
      : PatternProcessPass(name, multigraph), input_(std::make_shared<Var>()) {}
  ~GeLUFusion() override = default;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 protected:
  virtual bool CheckPattern(const EquivPtr &equiv) const = 0;
  const float GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const;
  VarPtr input_ = nullptr;

 private:
  CNodePtr CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
@@ -20,19 +20,19 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t AddInputSize = 3;
constexpr size_t MatMulInputSize = 3;
constexpr size_t kAddInputSize = 3;
constexpr size_t kMatMulInputSize = 3;

bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
  MS_ASSERT(cnode != nullptr);
  MS_ASSERT(index != nullptr);
  if (cnode->size() != AddInputSize) {
  if (cnode->size() != kAddInputSize) {
    return false;
  }
  size_t matmul_index = 0;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) {
      auto matmul_cnode = cnode->input(i)->cast<CNodePtr>();
      if (matmul_cnode->size() > MatMulInputSize) {
      if (matmul_cnode->size() > kMatMulInputSize) {
        continue;
      }
      matmul_index = i;
@@ -63,7 +63,7 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
      continue;
    }
    auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
    auto bias_node = cnode->input(AddInputSize - index);
    auto bias_node = cnode->input(kAddInputSize - index);
    if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
      continue;
    }
@@ -17,7 +17,6 @@
#include <memory>
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/reduce_fusion.h"
#include "ops/rsqrt.h"
#include "mindspore/core/ops/instance_norm.h"
#include "src/param_value_lite.h"
#include "utils/utils.h"
@@ -27,60 +26,6 @@
namespace mindspore {
namespace opt {
namespace {
inline bool IsAddNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
  }
  return false;
}

inline bool IsSquaredDifferenceNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference);
  }
  return false;
}

inline bool IsRsqrtNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt);
  }
  return false;
}

inline bool IsMulNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion);
  }
  return false;
}

inline bool IsSubNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion);
  }
  return false;
}

inline bool IsPowNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion);
  }
  return false;
}

inline bool IsSqrtNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt);
  }
  return false;
}

inline bool IsDivNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) ||
           CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion"));
  }
  return false;
}

STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
  MS_ASSERT(axes != nullptr);
  if (utils::isa<ParameterPtr>(n)) {
@@ -195,7 +140,7 @@ bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vect
    }
  }
  if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast<int>(shape.size())) {
    MS_LOG(DEBUG) << "mean node is not reduce to last axis";
    MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
    return false;
  }
@@ -318,37 +263,41 @@ const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNo
const BaseRef TfNormFusion::DefinePattern() const {
  VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
  auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
  auto squared_diffference1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquaredDifference>);
  VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
  auto mul1 = std::make_shared<CondVar>(IsMulNode);
  auto mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
  auto add1 = std::make_shared<CondVar>(IsAddNode);
  auto add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  VectorRef add1_ref = VectorRef({add1, mean2_ref, epsilon_});
  auto rsqrt1 = std::make_shared<CondVar>(IsRsqrtNode);
  auto rsqrt1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRsqrt>);
  VectorRef rsqrt1_ref = VectorRef({rsqrt1, add1_ref});
  auto mul2 = std::make_shared<CondVar>(IsMulNode);
  auto mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  VectorRef mul2_ref = VectorRef({mul2, rsqrt1_ref, gamma_});
  VectorRef mul1_ref = VectorRef({mul1, input_, mul2_ref});
  auto mul3 = std::make_shared<CondVar>(IsMulNode);
  auto mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  VectorRef mul3_ref = VectorRef({mul3, mean1_ref, mul2_ref});
  auto sub1 = std::make_shared<CondVar>(IsSubNode);
  auto sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
  VectorRef sub1_ref = VectorRef({sub1, beta_, mul3_ref});
  auto add2 = std::make_shared<CondVar>(IsAddNode);
  auto add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  VectorRef add2_ref = VectorRef({add2, mul1_ref, sub1_ref});
  return add2_ref;
}

const BaseRef OnnxLayerNormFusion::DefinePattern() const {
  VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
  VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
  VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
  VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()});
  VectorRef sub1_ref =
    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>), input_, mean1_ref});
  VectorRef sub2_ref =
    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>), input_, mean1_ref});
  VectorRef pow_ref =
    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>), sub2_ref, std::make_shared<Var>()});
  VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
  VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_});
  VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref});
  VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref});
  VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref});
  VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_});
  VectorRef add1_ref =
    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mean2_ref, epsilon_});
  VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>), add1_ref});
  VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), sub1_ref, sqrt_ref});
  VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), gamma_, div_ref});
  VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mul_ref, beta_});
  return add2_ref;
}
}  // namespace opt
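
For reference, both DefinePattern() variants above match the same normalization computation; the TF graph just distributes the multiply, which is why that pattern tracks mul2/mul3/sub1 as separate nodes. In LaTeX (a paraphrase of the matched graphs, not text from the PR):

  y = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
  % ONNX form: Sub, Pow/Mean, Add(eps), Sqrt, Div, Mul(gamma), Add(beta)

  y = x \cdot \bigl(\gamma \cdot \operatorname{rsqrt}(\sigma^{2} + \epsilon)\bigr)
      + \bigl(\beta - \mu \cdot \gamma \cdot \operatorname{rsqrt}(\sigma^{2} + \epsilon)\bigr)
  % TF form: mul1 = x * mul2, mul3 = mean * mul2, sub1 = beta - mul3, add2 = mul1 + sub1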
@@ -14,8 +14,8 @@
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_

#include <vector>
#include <memory>
@@ -31,7 +31,7 @@ namespace opt {
/// fuse layer_norm or instance_norm into one operator
class NormFusion : public PatternProcessPass {
 public:
  explicit NormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
  explicit NormFusion(const std::string &name = "norm_fusion", bool multigraph = true)
      : PatternProcessPass(name, multigraph) {
    input_ = std::make_shared<Var>();
    mean1_ = std::make_shared<Var>();
@@ -44,7 +44,6 @@ class NormFusion : public PatternProcessPass {
  }
  ~NormFusion() override = default;
  virtual const BaseRef DefinePattern() const = 0;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
@@ -70,6 +69,9 @@ class NormFusion : public PatternProcessPass {
/// fuse tf layer_norm or instance_norm into one operator
class TfNormFusion : public NormFusion {
 public:
  explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
      : NormFusion(name, multigraph) {}
  ~TfNormFusion() override = default;
  const BaseRef DefinePattern() const override;
};
@@ -77,11 +79,13 @@ class TfNormFusion : public NormFusion {
/// fuse onnx layer_norm into one operator
class OnnxLayerNormFusion : public NormFusion {
 public:
  explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true)
      : NormFusion(name, multigraph) {}
  ~OnnxLayerNormFusion() override = default;
  const BaseRef DefinePattern() const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
@@ -0,0 +1,55 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"

namespace mindspore {
namespace opt {
namespace {
constexpr float DIFF_THRESHOLD = 0.0001;
constexpr float DIV_Y = 1.41421;
constexpr float ADD_Y = 1.0;
constexpr float MUL1_Y = 0.5;
}  // namespace

// gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))]
const BaseRef OnnxGeLUFusion::DefinePattern() const {
  VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_});
  VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref});
  VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_});
  VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_});
  VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref});
  return mul2_ref;
}

bool OnnxGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
  MS_ASSERT(equiv != nullptr);
  float div_y = GetParameterValue(equiv, div_y_);
  if (div_y < 0 || fabs(div_y - DIV_Y) > DIFF_THRESHOLD) {
    return false;
  }
  float add_y = GetParameterValue(equiv, add_y_);
  if (add_y < 0 || fabs(add_y - ADD_Y) > DIFF_THRESHOLD) {
    return false;
  }
  float mul1_y = GetParameterValue(equiv, mul1_y_);
  if (mul1_y < 0 || fabs(mul1_y - MUL1_Y) > DIFF_THRESHOLD) {
    return false;
  }
  return true;
}
}  // namespace opt
}  // namespace mindspore
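
Sanity check on the constants checked above: the pattern encodes the exact (erf-based) GELU, so the matched scalars must equal \sqrt{2}, 1, and 1/2 within DIFF_THRESHOLD:

  \mathrm{GELU}(x) = x\,\Phi(x)
    = \underbrace{\tfrac{1}{2}}_{\texttt{MUL1\_Y}}\, x \,\Bigl[\underbrace{1}_{\texttt{ADD\_Y}}
      + \operatorname{erf}\bigl(x / \underbrace{\sqrt{2}}_{\texttt{DIV\_Y}\,\approx\,1.41421}\bigr)\Bigr]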
@@ -0,0 +1,49 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_

#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/gelu_fusion.h"

namespace mindspore {
namespace opt {
class OnnxGeLUFusion : public GeLUFusion {
 public:
  explicit OnnxGeLUFusion(const std::string &name = "onnx_gelu_fusion", bool multigraph = true)
      : GeLUFusion(name, multigraph) {
    div_y_ = std::make_shared<Var>();
    add_y_ = std::make_shared<Var>();
    mul1_y_ = std::make_shared<Var>();
  }
  ~OnnxGeLUFusion() override = default;

 private:
  bool CheckPattern(const EquivPtr &equiv) const override;
  const BaseRef DefinePattern() const override;

 private:
  VarPtr div_y_ = nullptr;
  VarPtr add_y_ = nullptr;
  VarPtr mul1_y_ = nullptr;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
@@ -133,7 +133,7 @@ AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr
  auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
  auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
  auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
  VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
  VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
@@ -183,13 +183,13 @@ AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr
  VectorRef select_hidden = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[4], new_hidden});
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
  std::vector<BaseRef> outputs = {is_make_tuple, add1, placeholders[1], add,
                                  output, select_hidden, placeholders[5], placeholders[6],
                                  placeholders[7]};
  outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  VectorRef return_node = VectorRef({is_return, make_tuple_node});

  VarPtr fg = std::make_shared<Var>("RootG");
@@ -0,0 +1,88 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace opt {
namespace {
constexpr float DIFF_THRESHOLD = 0.0001;
constexpr float POW_Y = 3;
constexpr float MUL1_X = 0.044715;
constexpr float MUL2_X = 0.79788;
constexpr float ADD2_X = 1.0;
constexpr float MUL3_X = 0.5;

bool CheckTanh(const EquivPtr &equiv, const VarPtr &input) {
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(input != nullptr);
  auto anf_node = utils::cast<AnfNodePtr>((*equiv)[input]);
  MS_ASSERT(anf_node != nullptr);
  AnfNodePtr value_node = anf_node;
  if (anf_node->isa<CNode>()) {
    value_node = anf_node->cast<CNodePtr>()->input(0);
  }
  auto act_prim = GetValueNode<PrimitivePtr>(value_node);
  if (act_prim == nullptr) {
    return false;
  }
  return act_prim->GetAttr(ops::kActivationType) != nullptr &&
         GetValue<int64_t>(act_prim->GetAttr(ops::kActivationType)) == mindspore::TANH;
}
}  // namespace

// gelu(x) = 1/2 * x * [1 + tanh(0.79788 * (x + 0.044715 * x ^ 3))]
const BaseRef TfGeLUFusion::DefinePattern() const {
  VectorRef pow_ref({power_, input_, power_y_});
  VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_x_, pow_ref});
  VectorRef add1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), input_, mul1_ref});
  VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul2_x_, add1_ref});
  VectorRef tanh_ref({tanh_, mul2_ref});
  VectorRef add2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), add2_x_, tanh_ref});
  VectorRef mul3_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul3_x_, add2_ref});
  VectorRef mul4_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul3_ref});
  return mul4_ref;
}

bool TfGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
  MS_ASSERT(equiv != nullptr);
  if (!CheckTanh(equiv, tanh_)) {
    return false;
  }
  float pow_y = GetParameterValue(equiv, power_y_);
  if (pow_y < 0 || fabs(pow_y - POW_Y) > DIFF_THRESHOLD) {
    return false;
  }
  float mul1_x = GetParameterValue(equiv, mul1_x_);
  if (mul1_x < 0 || fabs(mul1_x - MUL1_X) > DIFF_THRESHOLD) {
    return false;
  }
  float mul2_x = GetParameterValue(equiv, mul2_x_);
  if (mul2_x < 0 || fabs(mul2_x - MUL2_X) > DIFF_THRESHOLD) {
    return false;
  }
  float add2_x = GetParameterValue(equiv, add2_x_);
  if (add2_x < 0 || fabs(add2_x - ADD2_X) > DIFF_THRESHOLD) {
    return false;
  }
  float mul3_x = GetParameterValue(equiv, mul3_x_);
  if (mul3_x < 0 || fabs(mul3_x - MUL3_X) > DIFF_THRESHOLD) {
    return false;
  }
  return true;
}
}  // namespace opt
}  // namespace mindspore
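
Likewise, the TF pattern is the standard tanh approximation of GELU (Hendrycks & Gimpel), which fixes the constants checked above: MUL3_X = 1/2, ADD2_X = 1, MUL2_X ≈ \sqrt{2/\pi} ≈ 0.79788, MUL1_X = 0.044715, POW_Y = 3:

  \mathrm{GELU}(x) \approx \tfrac{1}{2}\, x \,\Bigl[1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\, x^{3})\bigr)\Bigr]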
@@ -0,0 +1,57 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_

#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/gelu_fusion.h"

namespace mindspore {
namespace opt {
class TfGeLUFusion : public GeLUFusion {
 public:
  explicit TfGeLUFusion(const std::string &name = "tf_gelu_fusion", bool multigraph = true)
      : GeLUFusion(name, multigraph) {
    power_ = std::make_shared<Var>();
    power_y_ = std::make_shared<Var>();
    mul1_x_ = std::make_shared<Var>();
    mul2_x_ = std::make_shared<Var>();
    tanh_ = std::make_shared<Var>();
    add2_x_ = std::make_shared<Var>();
    mul3_x_ = std::make_shared<Var>();
  }
  ~TfGeLUFusion() override = default;

 private:
  bool CheckPattern(const EquivPtr &equiv) const override;
  const BaseRef DefinePattern() const override;

 private:
  VarPtr power_ = nullptr;
  VarPtr power_y_ = nullptr;
  VarPtr mul1_x_ = nullptr;
  VarPtr mul2_x_ = nullptr;
  VarPtr tanh_ = nullptr;
  VarPtr add2_x_ = nullptr;
  VarPtr mul3_x_ = nullptr;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
@@ -98,11 +98,11 @@ AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primi
  VectorRef set_item = VectorRef({std::make_shared<Var>(""), placeholders[3], placeholders[2], new_hidden});
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
  std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden};
  outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  VectorRef return_node = VectorRef({is_return, make_tuple_node});

  VarPtr fg = std::make_shared<Var>("RootG");
@@ -116,7 +116,7 @@ AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &p
  auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
  auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
  auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
  VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
  VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
@@ -174,11 +174,11 @@ AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &p
  VectorRef set_item = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], output});
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
  std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
  outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  VectorRef return_node = VectorRef({is_return, make_tuple_node});

  VarPtr fg = std::make_shared<Var>("RootG");
@@ -41,8 +41,8 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) {
bool IsSpecialType(const CNodePtr &cnode) {
  if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
      CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, kPrimMakeTuple) ||
      CheckPrimitiveType(cnode, kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
      CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) ||
      CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
      CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) {
    return true;
  }
@@ -81,7 +81,7 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
  // concat body to cond
  std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode};
  if (CheckPrimitiveType(body_output_cnode, kPrimMakeTuple)) {
  if (CheckPrimitiveType(body_output_cnode, prim::kPrimMakeTuple)) {
    for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) {
      body_to_cond_inputs.emplace_back(body_output_cnode->input(i));
    }