From: @xu_anyue
Reviewed-by: @hangangqiang, @HilbertDavid
Signed-off-by: @hangangqiang
@@ -76,18 +76,12 @@ int SliceCPUKernel::SliceParallelRun(int thread_id) {
 }
 
 int SliceCPUKernel::Run() {
-  auto ret = PreProcess();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "PreProcess fail!ret: " << ret;
-    return ret;
-  }
   if (param_->size_[1] < op_parameter_->thread_num_) {
     DoSliceNoParallel(in_tensors_.at(0)->data_c(), out_tensors_.at(0)->data_c(), param_,
                       lite::DataTypeSize(in_tensors_.at(0)->data_type()));
     return RET_OK;
   }
-  ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
+  auto ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "slice launch fail!ret: " << ret;
     return RET_ERROR;
@@ -96,6 +90,5 @@ int SliceCPUKernel::Run() {
 }
 
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
-REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
 }  // namespace mindspore::kernel
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/arm/fp16/slice_fp16.h"
+#include "src/kernel_registry.h"
+#include "nnacl/base/slice_base.h"
+#include "nnacl/fp16/cast_fp16.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_SliceFusion;
+
+namespace mindspore::kernel {
+int SliceFp16Launch(void *cdata, int task_id) {
+  if (cdata == nullptr) {
+    MS_LOG(ERROR) << "Input cdata is nullptr!";
+    return RET_ERROR;
+  }
+  auto kernel = reinterpret_cast<SliceFp16CPUKernel *>(cdata);
+  return kernel->SliceFp16ParallelRun(task_id);
+}
+
+SliceFp16CPUKernel::~SliceFp16CPUKernel() {
+  if (input_data_ != nullptr) {
+    context_->allocator->Free(input_data_);
+    input_data_ = nullptr;
+  }
+}
+
+int SliceFp16CPUKernel::Init() {
+  auto input_tensor = in_tensors_.at(0);
+  if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data_c() != nullptr) {
+    input_data_ =
+      reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
+    Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
+  }
+  return SliceCPUKernel::Init();
+}
+
+int SliceFp16CPUKernel::SliceFp16ParallelRun(int thread_id) {
+  void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
+  DoSlice(input_data, out_tensors_.at(0)->data_c(), param_, thread_id, lite::DataTypeSize(kNumberTypeFloat16));
+  return RET_OK;
+}
+
+int SliceFp16CPUKernel::Run() {
+  void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
+  if (param_->size_[1] < op_parameter_->thread_num_) {
+    DoSliceNoParallel(input_data, out_tensors_.at(0)->data_c(), param_, lite::DataTypeSize(kNumberTypeFloat16));
+    return RET_OK;
+  }
+  auto ret = ParallelLaunch(this->context_->thread_pool_, SliceFp16Launch, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "fp16 slice launch fail!ret: " << ret;
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceFp16CPUKernel>)
+}  // namespace mindspore::kernel
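`Init()` eagerly casts a constant fp32 input to fp16 once and caches it in `input_data_`, so the per-inference `Run()` path never repeats the conversion. As a rough mental model, a sketch only and not the shipped nnacl helper (which is likely NEON-vectorized), `Float32ToFloat16` behaves like a plain narrowing loop:

```cpp
#include <arm_fp16.h>  // float16_t; assumes an ARM target built with fp16 support

// Reference-only equivalent of the nnacl Float32ToFloat16 call used in Init().
void Float32ToFloat16Ref(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = static_cast<float16_t>(input[i]);  // round-to-nearest narrowing cast
  }
}
```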
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
+
+#include <vector>
+#include "src/lite_kernel.h"
+#include "src/runtime/kernel/arm/base/slice_base.h"
+
+namespace mindspore::kernel {
+class SliceFp16CPUKernel : public SliceCPUKernel {
+ public:
+  SliceFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : SliceCPUKernel(parameter, inputs, outputs, ctx) {}
+  ~SliceFp16CPUKernel() override;
+
+  int Init() override;
+  int Run() override;
+  int SliceFp16ParallelRun(int thread_id);
+
+ private:
+  float16_t *input_data_ = nullptr;
+};
+}  // namespace mindspore::kernel
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
@@ -243,6 +243,9 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
         ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
@@ -98,7 +98,7 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
       MS_LOG(ERROR) << "value node is invalid.";
       return;
     }
-    if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTuple) ||
+    if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
                                            opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
       has_make_tuple = true;
       for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
@@ -372,7 +372,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
       ret = RET_MEMORY_FAILED;
       break;
     }
-    if (opt::CheckPrimitiveType(cnode, opt::kPrimReturn)) {
+    if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
      node->name = mindspore::ops::kNameReturn;
      ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get());
      if (ret != RET_OK) {
@@ -53,6 +53,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/fusion/tf_bidirection_gru_fusion.cc
         ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
         ../optimizer/fusion/matmul_add_fusion.cc
+        ../optimizer/fusion/gelu_fusion.cc
+        ../optimizer/fusion/tf_gelu_fusion.cc
+        ../optimizer/fusion/onnx_gelu_fusion.cc
         ../optimizer/graph/weight_format_transform_pass.cc
         ../optimizer/graph/weight_format_hardcode_pass.cc
         ../optimizer/graph/clip_convert_activation_pass.cc
@@ -37,6 +37,8 @@
 #include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
 #include "tools/optimizer/fusion/matmul_add_fusion.h"
 #include "tools/optimizer/graph/primitive_adjust_pass.h"
+#include "tools/optimizer/fusion/tf_gelu_fusion.h"
+#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
 #include "tools/optimizer/graph/mindir_adjust_pass.h"
 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
 #include "tools/optimizer/graph/weight_format_hardcode_pass.h"
@@ -89,6 +91,8 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
     fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
     fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
     fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
   }
   if (config->fmk == lite::converter::FmkType_MS) {
     auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
@@ -54,10 +54,10 @@ bool IsRealKernel(const AnfNodePtr &node) {
   auto input = cnode->inputs()[0];
   bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                          IsPrimitive(input, prim::kPrimTensorSummary) ||
-                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, kPrimMakeTuple) ||
+                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                          IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                          IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
-                         IsPrimitive(input, kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
+                         IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
   return !is_virtual_node;
 }
@@ -335,7 +335,7 @@ bool IsRealCNodeKernel(const AnfNodePtr &node) {
     return false;
   }
   // return considered as a real node
-  if (CheckPrimitiveType(node, kPrimReturn)) {
+  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
     return true;
   }
   return IsRealKernel(node);
@@ -35,8 +35,8 @@ using mindspore::lite::RET_OK;
 using mindspore::lite::STATUS;
 namespace mindspore {
 namespace opt {
-inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
-inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("MakeTuple");
+inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
+inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
 inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
 inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
 std::vector<int> CastToInt(const ValuePtr &value);
@@ -145,6 +145,15 @@ ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const st
 ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
                                           const std::string &node_name);
+
+template <const PrimitivePtr *prim = nullptr>
+inline bool IsSpecifiedNode(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    auto anf_node = utils::cast<AnfNodePtr>(n);
+    return CheckPrimitiveType(anf_node, *prim);
+  }
+  return false;
+}
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_
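The new `IsSpecifiedNode` template is the key refactor in this file: it replaces the family of hand-written per-primitive predicates (`IsAddNode`, `IsMulNode`, `IsSubNode`, ...) that norm_fusion.cc deletes below. A minimal before/after sketch using names from this diff:

```cpp
// Before: one predicate per primitive, each repeating the same boilerplate.
inline bool IsAddNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
  }
  return false;
}
auto add_old = std::make_shared<CondVar>(IsAddNode);

// After: the primitive is a non-type template argument. It must be the address
// of a PrimitivePtr with static storage duration, which is why call sites pass
// &prim::kPrimAddFusion rather than the PrimitivePtr value itself.
auto add_new = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
```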
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/fusion/gelu_fusion.h"
+#include <memory>
+#include <string>
+#include "ops/fusion/activation.h"
+#include "utils/utils.h"
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore {
+namespace opt {
+CNodePtr GeLUFusion::CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                    const EquivPtr &equiv) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  auto gelu_prim = std::make_shared<ops::Activation>();
+  gelu_prim->set_activation_type(mindspore::GELU);
+  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
+  MS_ASSERT(input_node != nullptr);
+  auto gelu_cnode = func_graph->NewCNode(gelu_prim, {input_node});
+  gelu_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_gelu");
+  gelu_cnode->set_abstract(node->abstract()->Clone());
+  return gelu_cnode;
+}
+
+const float GeLUFusion::GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const {
+  MS_ASSERT(equiv != nullptr);
+  MS_ASSERT(input != nullptr);
+  float value = -1;
+  auto node = utils::cast<AnfNodePtr>((*equiv)[input]);
+  if (node == nullptr || !utils::isa<ParameterPtr>(node)) {
+    return value;
+  }
+  auto parameter_node = node->cast<ParameterPtr>();
+  if (!parameter_node->has_default() || parameter_node->default_param() == nullptr) {
+    return value;
+  }
+  auto param_value_lite = parameter_node->default_param()->cast<ParamValueLitePtr>();
+  if (param_value_lite == nullptr) {
+    return value;
+  }
+  if (param_value_lite->tensor_type() != kNumberTypeFloat32 && param_value_lite->tensor_type() != kNumberTypeFloat) {
+    return value;
+  }
+  if (param_value_lite->tensor_size() != sizeof(float)) {
+    return value;
+  }
+  return *static_cast<float *>(param_value_lite->tensor_addr());
+}
+
+const AnfNodePtr GeLUFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                     const EquivPtr &equiv) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  MS_ASSERT(equiv != nullptr);
+  MS_LOG(DEBUG) << "gelu_fusion pass";
+  if (!utils::isa<CNodePtr>(node)) {
+    return nullptr;
+  }
+  if (!CheckPattern(equiv)) {
+    return nullptr;
+  }
+  auto cnode = CreateGeLUNode(func_graph, node, equiv);
+  if (cnode == nullptr) {
+    MS_LOG(DEBUG) << "new gelu node failed.";
+    return nullptr;
+  }
+  return cnode;
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
+
+#include <vector>
+#include <memory>
+#include <string>
+#include "backend/optimizer/common/optimizer.h"
+#include "utils/utils.h"
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore {
+namespace opt {
+class GeLUFusion : public PatternProcessPass {
+ public:
+  explicit GeLUFusion(const std::string &name = "gelu_fusion", bool multigraph = true)
+      : PatternProcessPass(name, multigraph), input_(std::make_shared<Var>()) {}
+  ~GeLUFusion() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ protected:
+  virtual bool CheckPattern(const EquivPtr &equiv) const = 0;
+  const float GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const;
+  VarPtr input_ = nullptr;
+
+ private:
+  CNodePtr CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
@@ -20,19 +20,19 @@
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t AddInputSize = 3;
-constexpr size_t MatMulInputSize = 3;
+constexpr size_t kAddInputSize = 3;
+constexpr size_t kMatMulInputSize = 3;
 
 bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
   MS_ASSERT(cnode != nullptr);
   MS_ASSERT(index != nullptr);
-  if (cnode->size() != AddInputSize) {
+  if (cnode->size() != kAddInputSize) {
     return false;
   }
   size_t matmul_index = 0;
   for (size_t i = 1; i < cnode->size(); ++i) {
     if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) {
       auto matmul_cnode = cnode->input(i)->cast<CNodePtr>();
-      if (matmul_cnode->size() > MatMulInputSize) {
+      if (matmul_cnode->size() > kMatMulInputSize) {
         continue;
       }
       matmul_index = i;
@@ -63,7 +63,7 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
       continue;
     }
     auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
-    auto bias_node = cnode->input(AddInputSize - index);
+    auto bias_node = cnode->input(kAddInputSize - index);
     if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
       continue;
     }
@@ -17,7 +17,6 @@
 #include <memory>
 #include "ops/fusion/layer_norm_fusion.h"
 #include "ops/fusion/reduce_fusion.h"
-#include "ops/rsqrt.h"
 #include "mindspore/core/ops/instance_norm.h"
 #include "src/param_value_lite.h"
 #include "utils/utils.h"
@@ -27,60 +26,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-inline bool IsAddNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
-  }
-  return false;
-}
-
-inline bool IsSquaredDifferenceNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference);
-  }
-  return false;
-}
-
-inline bool IsRsqrtNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt);
-  }
-  return false;
-}
-
-inline bool IsMulNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion);
-  }
-  return false;
-}
-
-inline bool IsSubNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion);
-  }
-  return false;
-}
-
-inline bool IsPowNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion);
-  }
-  return false;
-}
-
-inline bool IsSqrtNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt);
-  }
-  return false;
-}
-
-inline bool IsDivNode(const BaseRef &n) {
-  if (utils::isa<AnfNodePtr>(n)) {
-    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) ||
-           CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion"));
-  }
-  return false;
-}
 STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
   MS_ASSERT(node != nullptr);
   if (utils::isa<ParameterPtr>(n)) {
@@ -195,7 +140,7 @@ bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vect
     }
   }
   if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast<int>(shape.size())) {
-    MS_LOG(DEBUG) << "mean node is not reduce to last axis";
+    MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
     return false;
   }
@@ -318,37 +263,41 @@ const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNo
 const BaseRef TfNormFusion::DefinePattern() const {
   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
-  auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
+  auto squared_diffference1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquaredDifference>);
   VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
-  auto mul1 = std::make_shared<CondVar>(IsMulNode);
+  auto mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
   VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
-  auto add1 = std::make_shared<CondVar>(IsAddNode);
+  auto add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
   VectorRef add1_ref = VectorRef({add1, mean2_ref, epsilon_});
-  auto rsqrt1 = std::make_shared<CondVar>(IsRsqrtNode);
+  auto rsqrt1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRsqrt>);
   VectorRef rsqrt1_ref = VectorRef({rsqrt1, add1_ref});
-  auto mul2 = std::make_shared<CondVar>(IsMulNode);
+  auto mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
   VectorRef mul2_ref = VectorRef({mul2, rsqrt1_ref, gamma_});
   VectorRef mul1_ref = VectorRef({mul1, input_, mul2_ref});
-  auto mul3 = std::make_shared<CondVar>(IsMulNode);
+  auto mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
   VectorRef mul3_ref = VectorRef({mul3, mean1_ref, mul2_ref});
-  auto sub1 = std::make_shared<CondVar>(IsSubNode);
+  auto sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
   VectorRef sub1_ref = VectorRef({sub1, beta_, mul3_ref});
-  auto add2 = std::make_shared<CondVar>(IsAddNode);
+  auto add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
   VectorRef add2_ref = VectorRef({add2, mul1_ref, sub1_ref});
   return add2_ref;
 }
 
 const BaseRef OnnxLayerNormFusion::DefinePattern() const {
   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
-  VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
-  VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
-  VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()});
+  VectorRef sub1_ref =
+    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>), input_, mean1_ref});
+  VectorRef sub2_ref =
+    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>), input_, mean1_ref});
+  VectorRef pow_ref =
+    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>), sub2_ref, std::make_shared<Var>()});
   VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
-  VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_});
-  VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref});
-  VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref});
-  VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref});
-  VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_});
+  VectorRef add1_ref =
+    VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mean2_ref, epsilon_});
+  VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>), add1_ref});
+  VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), sub1_ref, sqrt_ref});
+  VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), gamma_, div_ref});
+  VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mul_ref, beta_});
   return add2_ref;
 }
 }  // namespace opt
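For context on what these patterns match: both are unfused decompositions of normalization, y = gamma * (x - mean) / sqrt(var + eps) + beta, reduced over the last axis in the layer-norm case. A scalar reference of the fused op's semantics (my sketch for review purposes, not the nnacl kernel):

```cpp
#include <cmath>
#include <vector>

// Reference layer norm over one normalized axis of length n.
std::vector<float> LayerNormRef(const std::vector<float> &x, const std::vector<float> &gamma,
                                const std::vector<float> &beta, float eps) {
  const size_t n = x.size();
  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= n;
  float var = 0.0f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= n;
  std::vector<float> y(n);
  for (size_t i = 0; i < n; ++i) {
    y[i] = gamma[i] * (x[i] - mean) / std::sqrt(var + eps) + beta[i];
  }
  return y;
}
```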
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
-#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
 
 #include <vector>
 #include <memory>
@@ -31,7 +31,7 @@ namespace opt {
 /// fuse layer_norm or instance_norm into one operator
 class NormFusion : public PatternProcessPass {
  public:
-  explicit NormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
+  explicit NormFusion(const std::string &name = "norm_fusion", bool multigraph = true)
       : PatternProcessPass(name, multigraph) {
     input_ = std::make_shared<Var>();
     mean1_ = std::make_shared<Var>();
@@ -44,7 +44,6 @@ class NormFusion : public PatternProcessPass {
   }
   ~NormFusion() override = default;
-  virtual const BaseRef DefinePattern() const = 0;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
@@ -70,6 +69,9 @@ class NormFusion : public PatternProcessPass {
 /// fuse tf layer_norm or instance_norm into one operator
 class TfNormFusion : public NormFusion {
  public:
+  explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
+      : NormFusion(name, multigraph) {}
   ~TfNormFusion() override = default;
   const BaseRef DefinePattern() const override;
 };
@@ -77,11 +79,13 @@ class TfNormFusion : public NormFusion {
 /// fuse onnx layer_norm into one operator
 class OnnxLayerNormFusion : public NormFusion {
  public:
+  explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true)
+      : NormFusion(name, multigraph) {}
   ~OnnxLayerNormFusion() override = default;
   const BaseRef DefinePattern() const override;
 };
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr float DIFF_THRESHOLD = 0.0001;
+constexpr float DIV_Y = 1.41421;
+constexpr float ADD_Y = 1.0;
+constexpr float MUL1_Y = 0.5;
+}  // namespace
+
+// gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))]
+const BaseRef OnnxGeLUFusion::DefinePattern() const {
+  VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_});
+  VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref});
+  VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_});
+  VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_});
+  VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref});
+  return mul2_ref;
+}
+
+bool OnnxGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
+  MS_ASSERT(equiv != nullptr);
+  float div_y = GetParameterValue(equiv, div_y_);
+  if (div_y < 0 || fabs(div_y - DIV_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float add_y = GetParameterValue(equiv, add_y_);
+  if (add_y < 0 || fabs(add_y - ADD_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float mul1_y = GetParameterValue(equiv, mul1_y_);
+  if (mul1_y < 0 || fabs(mul1_y - MUL1_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
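The checked constants fall straight out of the erf form in the comment above: DIV_Y ≈ sqrt(2) = 1.41421..., ADD_Y = 1, MUL1_Y = 0.5, with DIFF_THRESHOLD absorbing converter rounding. A one-line reference of what the matched subgraph computes (sketch only, not part of the converter):

```cpp
#include <cmath>

// Exact (erf-based) GeLU: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
float GeluErfRef(float x) { return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); }
```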
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
+
+#include <vector>
+#include <memory>
+#include <string>
+#include "tools/optimizer/fusion/gelu_fusion.h"
+
+namespace mindspore {
+namespace opt {
+class OnnxGeLUFusion : public GeLUFusion {
+ public:
+  explicit OnnxGeLUFusion(const std::string &name = "onnx_gelu_fusion", bool multigraph = true)
+      : GeLUFusion(name, multigraph) {
+    div_y_ = std::make_shared<Var>();
+    add_y_ = std::make_shared<Var>();
+    mul1_y_ = std::make_shared<Var>();
+  }
+  ~OnnxGeLUFusion() override = default;
+
+ private:
+  bool CheckPattern(const EquivPtr &equiv) const override;
+  const BaseRef DefinePattern() const override;
+
+ private:
+  VarPtr div_y_ = nullptr;
+  VarPtr add_y_ = nullptr;
+  VarPtr mul1_y_ = nullptr;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
@@ -133,7 +133,7 @@ AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr
   auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
   auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
   auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
   VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
   VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
@@ -183,13 +183,13 @@ AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr
   VectorRef select_hidden = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[4], new_hidden});
 
-  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
   std::vector<BaseRef> outputs = {is_make_tuple, add1,          placeholders[1], add,
                                   output,        select_hidden, placeholders[5], placeholders[6],
                                   placeholders[7]};
   outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end());
   VectorRef make_tuple_node = VectorRef(outputs);
 
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef return_node = VectorRef({is_return, make_tuple_node});
 
   VarPtr fg = std::make_shared<Var>("RootG");
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/fusion/tf_gelu_fusion.h"
+#include "ops/op_utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr float DIFF_THRESHOLD = 0.0001;
+constexpr float POW_Y = 3;
+constexpr float MUL1_Y = 0.044715;
+constexpr float MUL2_X = 0.79788;
+constexpr float ADD2_X = 1.0;
+constexpr float MUL3_X = 0.5;
+
+bool CheckTanh(const EquivPtr &equiv, const VarPtr &input) {
+  MS_ASSERT(equiv != nullptr);
+  MS_ASSERT(input != nullptr);
+  auto anf_node = utils::cast<AnfNodePtr>((*equiv)[input]);
+  MS_ASSERT(anf_node != nullptr);
+  AnfNodePtr value_node = anf_node;
+  if (anf_node->isa<CNode>()) {
+    value_node = anf_node->cast<CNodePtr>()->input(0);
+  }
+  auto act_prim = GetValueNode<PrimitivePtr>(value_node);
+  if (act_prim == nullptr) {
+    return false;
+  }
+  return act_prim->GetAttr(ops::kActivationType) != nullptr &&
+         GetValue<int64_t>(act_prim->GetAttr(ops::kActivationType)) == mindspore::TANH;
+}
+}  // namespace
+
+// gelu(x) = 1/2 * x * [1 + tanh(0.79788 * (x + 0.044715 * x ^ 3))]
+const BaseRef TfGeLUFusion::DefinePattern() const {
+  VectorRef pow_ref({power_, input_, power_y_});
+  VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_x_, pow_ref});
+  VectorRef add1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), input_, mul1_ref});
+  VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul2_x_, add1_ref});
+  VectorRef tanh_ref({tanh_, mul2_ref});
+  VectorRef add2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), add2_x_, tanh_ref});
+  VectorRef mul3_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul3_x_, add2_ref});
+  VectorRef mul4_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul3_ref});
+  return mul4_ref;
+}
+
+bool TfGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
+  MS_ASSERT(equiv != nullptr);
+  if (!CheckTanh(equiv, tanh_)) {
+    return false;
+  }
+  float pow_y = GetParameterValue(equiv, power_y_);
+  if (pow_y < 0 || fabs(pow_y - POW_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float mul1_y = GetParameterValue(equiv, mul1_x_);
+  if (mul1_y < 0 || fabs(mul1_y - MUL1_Y) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float mul2_x = GetParameterValue(equiv, mul2_x_);
+  if (mul2_x < 0 || fabs(mul2_x - MUL2_X) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float add2_x = GetParameterValue(equiv, add2_x_);
+  if (add2_x < 0 || fabs(add2_x - ADD2_X) > DIFF_THRESHOLD) {
+    return false;
+  }
+  float mul3_x = GetParameterValue(equiv, mul3_x_);
+  if (mul3_x < 0 || fabs(mul3_x - MUL3_X) > DIFF_THRESHOLD) {
+    return false;
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
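Same idea for the TF pattern: MUL2_X ≈ sqrt(2 / pi) = 0.79788... and MUL1_Y = 0.044715 are the standard constants of the tanh approximation, so CheckPattern only fuses when the graph really encodes GeLU rather than a similar-shaped expression. Reference sketch of the matched computation (not converter code):

```cpp
#include <cmath>

// Tanh-approximated GeLU matched by TfGeLUFusion:
// gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
float GeluTanhRef(float x) {
  return 0.5f * x * (1.0f + std::tanh(0.79788f * (x + 0.044715f * x * x * x)));
}
```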
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
+
+#include <vector>
+#include <memory>
+#include <string>
+#include "tools/optimizer/fusion/gelu_fusion.h"
+
+namespace mindspore {
+namespace opt {
+class TfGeLUFusion : public GeLUFusion {
+ public:
+  explicit TfGeLUFusion(const std::string &name = "tf_gelu_fusion", bool multigraph = true)
+      : GeLUFusion(name, multigraph) {
+    power_ = std::make_shared<Var>();
+    power_y_ = std::make_shared<Var>();
+    mul1_x_ = std::make_shared<Var>();
+    mul2_x_ = std::make_shared<Var>();
+    tanh_ = std::make_shared<Var>();
+    add2_x_ = std::make_shared<Var>();
+    mul3_x_ = std::make_shared<Var>();
+  }
+  ~TfGeLUFusion() override = default;
+
+ private:
+  bool CheckPattern(const EquivPtr &equiv) const override;
+  const BaseRef DefinePattern() const override;
+
+ private:
+  VarPtr power_ = nullptr;
+  VarPtr power_y_ = nullptr;
+  VarPtr mul1_x_ = nullptr;
+  VarPtr mul2_x_ = nullptr;
+  VarPtr tanh_ = nullptr;
+  VarPtr add2_x_ = nullptr;
+  VarPtr mul3_x_ = nullptr;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
@@ -98,11 +98,11 @@ AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primi
   VectorRef set_item = VectorRef({std::make_shared<Var>(""), placeholders[3], placeholders[2], new_hidden});
 
-  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
   std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden};
   outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
   VectorRef make_tuple_node = VectorRef(outputs);
 
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef return_node = VectorRef({is_return, make_tuple_node});
 
   VarPtr fg = std::make_shared<Var>("RootG");
@@ -116,7 +116,7 @@ AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &p
   auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
   auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
   auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
   VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
   VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
@@ -174,11 +174,11 @@ AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &p
   VectorRef set_item = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], output});
 
-  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
   std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
   outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
   VectorRef make_tuple_node = VectorRef(outputs);
 
-  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
   VectorRef return_node = VectorRef({is_return, make_tuple_node});
 
   VarPtr fg = std::make_shared<Var>("RootG");
@@ -41,8 +41,8 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) {
 bool IsSpecialType(const CNodePtr &cnode) {
   if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
-      CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, kPrimMakeTuple) ||
-      CheckPrimitiveType(cnode, kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
+      CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) ||
+      CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
       CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) {
     return true;
   }
@@ -81,7 +81,7 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
   // concat body to cond
   std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode};
-  if (CheckPrimitiveType(body_output_cnode, kPrimMakeTuple)) {
+  if (CheckPrimitiveType(body_output_cnode, prim::kPrimMakeTuple)) {
     for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) {
       body_to_cond_inputs.emplace_back(body_output_cnode->input(i));
     }