From: @wangzhe128 Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiang tags/v1.2.0-rc1
| @@ -41,6 +41,20 @@ inline void Int32ToFloat32(const int32_t *input, float *output, int number) { | |||||
| } | } | ||||
| } | } | ||||
| inline void Int64ToFloat32(const int64_t *input, float *output, int number) { | |||||
| for (int i = 0; i < number; ++i) { | |||||
| output[i] = (float)input[i]; | |||||
| } | |||||
| } | |||||
#ifdef ENABLE_FP16
// Converts the first `number` int64 elements of `input` to float16_t and stores them in `output`.
// Only compiled when ENABLE_FP16 is defined.
inline void Int64ToFp16(const int64_t *input, float16_t *output, int number) {
  int idx = 0;
  while (idx < number) {
    output[idx] = (float16_t)input[idx];
    ++idx;
  }
}
#endif
| inline void Fp16ToFloat32(const uint16_t *input, float *output, int number) { | inline void Fp16ToFloat32(const uint16_t *input, float *output, int number) { | ||||
| for (int i = 0; i < number; ++i) { | for (int i = 0; i < number; ++i) { | ||||
| output[i] = ShortToFloat32(input[i]); | output[i] = ShortToFloat32(input[i]); | ||||
| @@ -82,6 +96,7 @@ inline void BoolToInt32(const bool *input, int32_t *output, int number) { | |||||
| output[i] = (int32_t)input[i]; | output[i] = (int32_t)input[i]; | ||||
| } | } | ||||
| } | } | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -36,7 +36,8 @@ int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o | |||||
| } | } | ||||
| if (input->data_type_ != kNumberTypeBool && input->data_type_ != kNumberTypeUInt8 && | if (input->data_type_ != kNumberTypeBool && input->data_type_ != kNumberTypeUInt8 && | ||||
| input->data_type_ != kNumberTypeInt8 && input->data_type_ != kNumberTypeInt32 && | input->data_type_ != kNumberTypeInt8 && input->data_type_ != kNumberTypeInt32 && | ||||
| input->data_type_ != kNumberTypeFloat32 && input->data_type_ != kNumberTypeFloat16) { | |||||
| input->data_type_ != kNumberTypeInt64 && input->data_type_ != kNumberTypeFloat32 && | |||||
| input->data_type_ != kNumberTypeFloat16) { | |||||
| return NNACL_INPUT_TENSOR_ERROR; | return NNACL_INPUT_TENSOR_ERROR; | ||||
| } | } | ||||
| @@ -121,6 +121,17 @@ int CastFp16CPUKernel::DoCast(int thread_id) { | |||||
| MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; | MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } else if (input_data_type == kNumberTypeInt64) { | |||||
| switch (output_data_type) { | |||||
| case kNumberTypeFloat16: | |||||
| Int64ToFloat32(reinterpret_cast<int64_t *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -136,4 +147,5 @@ int CastFp16CPUKernel::Run() { | |||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>) | REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>) | ||||
| REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -53,6 +53,37 @@ int CastCPUKernel::ReSize() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int CastCPUKernel::CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num) { | |||||
| auto input_data_type = input->data_type(); | |||||
| auto output_data = output->data_c(); | |||||
| switch (input_data_type) { | |||||
| case kNumberTypeBool: | |||||
| BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| case kNumberTypeUInt8: | |||||
| Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| case kNumberTypeInt32: | |||||
| Int32ToFloat32(reinterpret_cast<int32_t *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| case kNumberTypeFloat16: | |||||
| Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| case kNumberTypeInt64: | |||||
| Int64ToFloat32(reinterpret_cast<int64_t *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int CastCPUKernel::DoCast(int thread_id) { | int CastCPUKernel::DoCast(int thread_id) { | ||||
| auto input = in_tensors_.at(0); | auto input = in_tensors_.at(0); | ||||
| int data_num = MSMIN(stride_, data_num_ - thread_id * stride_); | int data_num = MSMIN(stride_, data_num_ - thread_id * stride_); | ||||
| @@ -91,32 +122,17 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| } else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { | } else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { | ||||
| BoolToInt32(reinterpret_cast<bool *>(input->data_c()) + offset, reinterpret_cast<int32_t *>(output_data) + offset, | BoolToInt32(reinterpret_cast<bool *>(input->data_c()) + offset, reinterpret_cast<int32_t *>(output_data) + offset, | ||||
| data_num); | data_num); | ||||
| #ifdef ENABLE_FP16 | |||||
| } else if (input_data_type == kNumberTypeInt64 && output_data_type == kNumberTypeFloat16) { | |||||
| Int64ToFp16(reinterpret_cast<int64_t *>(input->data_c()) + offset, | |||||
| reinterpret_cast<float16_t *>(output_data) + offset, data_num); | |||||
| #endif | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } else { | } else { | ||||
| switch (input_data_type) { | |||||
| case kNumberTypeBool: | |||||
| BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| case kNumberTypeUInt8: | |||||
| Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| case kNumberTypeInt32: | |||||
| Int32ToFloat32(reinterpret_cast<int32_t *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| case kNumberTypeFloat16: | |||||
| Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return CastToFp32(input, output, offset, data_num); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -132,6 +148,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, LiteKernelCreator<CastC | |||||
| REG_KERNEL(kCPU, kNumberTypeUInt8, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | REG_KERNEL(kCPU, kNumberTypeUInt8, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | ||||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | ||||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | ||||
| REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | ||||
| #ifndef ENABLE_ARM | #ifndef ENABLE_ARM | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | ||||
| @@ -38,6 +38,7 @@ class CastCPUKernel : public LiteKernel { | |||||
| int DoCast(int thread_id); | int DoCast(int thread_id); | ||||
| private: | private: | ||||
| int CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num); | |||||
| int stride_; | int stride_; | ||||
| int data_num_; | int data_num_; | ||||
| }; | }; | ||||
| @@ -233,7 +233,8 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/layer_norm_fusion.cc | |||||
| ${LITE_DIR}/tools/optimizer/fusion/tf_norm_fusion.cc | |||||
| ${LITE_DIR}/tools/optimizer/fusion/onnx_layer_norm_fusion.cc | |||||
| ${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc | ||||
| @@ -38,7 +38,7 @@ ml_video_edit_img_segment 1 | |||||
| ml_video_edit_video_segment_gauss_adaptis_part1 2 | ml_video_edit_video_segment_gauss_adaptis_part1 2 | ||||
| ml_video_edit_generate_filter.pb 1 | ml_video_edit_generate_filter.pb 1 | ||||
| ml_video_edit_img_segment_adaptise.pb 0.5 2 | ml_video_edit_img_segment_adaptise.pb 0.5 2 | ||||
| ml_video_edit_video_segment_gauss_adaptis_part2.pb 3 2 | |||||
| ml_video_edit_video_segment_gauss_adaptis_part2.pb 10 2 | |||||
| ml_video_edit_person_divison_pic 0.5 | ml_video_edit_person_divison_pic 0.5 | ||||
| ml_video_edit_person_divison_video 13 2 | ml_video_edit_person_divison_video 13 2 | ||||
| ml_video_edit_imitate_filter.onnx 230 | ml_video_edit_imitate_filter.onnx 230 | ||||
| @@ -51,6 +51,8 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveT | |||||
| schema::PrimitiveType_SpaceToBatch, | schema::PrimitiveType_SpaceToBatch, | ||||
| schema::PrimitiveType_SpaceToBatchND}; | schema::PrimitiveType_SpaceToBatchND}; | ||||
// Primitive types handed out by GetNchwOpList(); currently only InstanceNorm.
// NOTE(review): presumably these are ops whose kernels expect NCHW data - confirm against the format-transform pass.
| static const std::vector<schema::PrimitiveType> nchwOpList = {schema::PrimitiveType_InstanceNorm}; | |||||
| static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | ||||
| schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, | schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, | ||||
| schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion, | schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion, | ||||
| @@ -153,6 +155,8 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; | |||||
| std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; } | std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; } | ||||
| std::vector<schema::PrimitiveType> GetNchwOpList() { return nchwOpList; } | |||||
| std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; } | std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; } | ||||
| std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; } | std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; } | ||||
| @@ -60,6 +60,8 @@ std::vector<schema::PrimitiveType> GetInsertOpList(); | |||||
| std::vector<schema::PrimitiveType> GetNhwcOpList(); | std::vector<schema::PrimitiveType> GetNhwcOpList(); | ||||
| std::vector<schema::PrimitiveType> GetNchwOpList(); | |||||
| std::vector<schema::PrimitiveType> GetNhwcAllInputOpList(); | std::vector<schema::PrimitiveType> GetNhwcAllInputOpList(); | ||||
| std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes(); | std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes(); | ||||
| @@ -44,7 +44,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/fusion/conv_tuplegetitem_fusion.cc | ../optimizer/fusion/conv_tuplegetitem_fusion.cc | ||||
| ../optimizer/fusion/constant_folding_fusion.cc | ../optimizer/fusion/constant_folding_fusion.cc | ||||
| ../optimizer/fusion/quant_dtype_cast_fusion.cc | ../optimizer/fusion/quant_dtype_cast_fusion.cc | ||||
| ../optimizer/fusion/layer_norm_fusion.cc | |||||
| ../optimizer/fusion/tf_norm_fusion.cc | |||||
| ../optimizer/fusion/onnx_layer_norm_fusion.cc | |||||
| ../optimizer/fusion/batchmatmul_fusion.cc | ../optimizer/fusion/batchmatmul_fusion.cc | ||||
| ../optimizer/fusion/sigmoid_mul_fusion.cc | ../optimizer/fusion/sigmoid_mul_fusion.cc | ||||
| ../optimizer/fusion/conv_conv_fusion.cc | ../optimizer/fusion/conv_conv_fusion.cc | ||||
| @@ -27,7 +27,8 @@ | |||||
| #include "tools/optimizer/fusion/conv_bn_fusion.h" | #include "tools/optimizer/fusion/conv_bn_fusion.h" | ||||
| #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" | #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" | ||||
| #include "tools/optimizer/fusion/constant_folding_fusion.h" | #include "tools/optimizer/fusion/constant_folding_fusion.h" | ||||
| #include "tools/optimizer/fusion/layer_norm_fusion.h" | |||||
| #include "tools/optimizer/fusion/tf_norm_fusion.h" | |||||
| #include "tools/optimizer/fusion/onnx_layer_norm_fusion.h" | |||||
| #include "tools/optimizer/fusion/batchmatmul_fusion.h" | #include "tools/optimizer/fusion/batchmatmul_fusion.h" | ||||
| #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | ||||
| #include "tools/optimizer/fusion/conv_conv_fusion.h" | #include "tools/optimizer/fusion/conv_conv_fusion.h" | ||||
| @@ -77,7 +78,8 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti | |||||
| auto conv_scale_pass = std::make_shared<opt::ConvScaleFusion>(); | auto conv_scale_pass = std::make_shared<opt::ConvScaleFusion>(); | ||||
| conv_scale_pass->SetFmkType(config->fmk); | conv_scale_pass->SetFmkType(config->fmk); | ||||
| fusion_pm->AddPass(conv_scale_pass); | fusion_pm->AddPass(conv_scale_pass); | ||||
| fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>()); | |||||
| fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>()); | |||||
| fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>()); | |||||
| fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>()); | fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>()); | ||||
| fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>()); | fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>()); | ||||
| fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>()); | fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>()); | ||||
| @@ -48,7 +48,12 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { | |||||
| STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType, | STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType, | ||||
| FormatTransNodeType *afterNodeType) { | FormatTransNodeType *afterNodeType) { | ||||
| if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc | if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc | ||||
| return RET_NO_CHANGE; | |||||
| if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) { | |||||
| return RET_NO_CHANGE; | |||||
| } | |||||
| *beforeNodeType = kNHWC2NCHW; | |||||
| *afterNodeType = kNCHW2NHWC; | |||||
| return RET_OK; | |||||
| } else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS || | } else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS || | ||||
| fmk_type_ == converter::FmkType_ONNX) { | fmk_type_ == converter::FmkType_ONNX) { | ||||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { | if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { | ||||
| @@ -63,6 +68,11 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT | |||||
| *afterNodeType = kNHWC2NCHW; | *afterNodeType = kNHWC2NCHW; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| if (IsContain(GetNchwOpList(), GetCNodeTType(node))) { | |||||
| *beforeNodeType = kNHWC2NCHW; | |||||
| *afterNodeType = kNCHW2NHWC; | |||||
| return RET_OK; | |||||
| } | |||||
| return RET_NO_CHANGE; | return RET_NO_CHANGE; | ||||
| } | } | ||||
| MS_LOG(ERROR) << "Unsupported fmk: " << fmk_type_; | MS_LOG(ERROR) << "Unsupported fmk: " << fmk_type_; | ||||
| @@ -1,65 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class LayerNormFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit LayerNormFusion(const std::string &name = "layer_norm_fusion", bool multigraph = true) | |||||
| : PatternProcessPass(name, multigraph) { | |||||
| input_ = std::make_shared<Var>(); | |||||
| mean1_ = std::make_shared<Var>(); | |||||
| mean1_axes_ = std::make_shared<Var>(); | |||||
| mean2_ = std::make_shared<Var>(); | |||||
| mean2_axes_ = std::make_shared<Var>(); | |||||
| gamma_ = std::make_shared<Var>(); | |||||
| beta_ = std::make_shared<Var>(); | |||||
| epsilon_ = std::make_shared<Var>(); | |||||
| } | |||||
| ~LayerNormFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| bool GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, const std::vector<int> ¶ms_shape, | |||||
| int *begin_norm_axis, int *begin_params_axis) const; | |||||
| bool CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, int *begin_params_axis) const; | |||||
| CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon, | |||||
| int begin_norm_axis, int begin_params_axis) const; | |||||
| VarPtr input_ = nullptr; | |||||
| VarPtr mean1_ = nullptr; | |||||
| VarPtr mean1_axes_ = nullptr; | |||||
| VarPtr mean2_ = nullptr; | |||||
| VarPtr mean2_axes_ = nullptr; | |||||
| VarPtr gamma_ = nullptr; | |||||
| VarPtr beta_ = nullptr; | |||||
| VarPtr epsilon_ = nullptr; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/optimizer/fusion/onnx_layer_norm_fusion.h" | |||||
| #include <memory> | |||||
| #include "ops/rsqrt.h" | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const BaseRef OnnxLayerNormFusion::DefinePattern() const { | |||||
| VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_}); | |||||
| VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref}); | |||||
| VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref}); | |||||
| VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()}); | |||||
| VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_}); | |||||
| VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_}); | |||||
| VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref}); | |||||
| VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref}); | |||||
| VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref}); | |||||
| VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_}); | |||||
| return add2_ref; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,60 @@ | |||||
| /** | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "tools/optimizer/fusion/tf_norm_fusion.h" | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class OnnxLayerNormFusion : public TfNormFusion { | |||||
| public: | |||||
| explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true) | |||||
| : TfNormFusion(name, multigraph) {} | |||||
| ~OnnxLayerNormFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| inline bool IsPowNode(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| inline bool IsSqrtNode(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| inline bool IsDivNode(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) || | |||||
| CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion")); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ | |||||
| @@ -13,11 +13,12 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "tools/optimizer/fusion/layer_norm_fusion.h" | |||||
| #include "tools/optimizer/fusion/tf_norm_fusion.h" | |||||
| #include <memory> | #include <memory> | ||||
| #include "ops/fusion/layer_norm_fusion.h" | #include "ops/fusion/layer_norm_fusion.h" | ||||
| #include "ops/fusion/reduce_fusion.h" | #include "ops/fusion/reduce_fusion.h" | ||||
| #include "ops/rsqrt.h" | #include "ops/rsqrt.h" | ||||
| #include "mindspore/core/ops/instance_norm.h" | |||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "tools/optimizer/common/gllo_utils.h" | #include "tools/optimizer/common/gllo_utils.h" | ||||
| @@ -26,41 +27,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| bool IsAddNode(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsSquaredDifferenceNode(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsRsqrtNode(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsMulNode(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsSubNode(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| lite::STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) { | lite::STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) { | ||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| if (utils::isa<ParameterPtr>(n)) { | if (utils::isa<ParameterPtr>(n)) { | ||||
| @@ -106,7 +72,7 @@ bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| const BaseRef LayerNormFusion::DefinePattern() const { | |||||
| const BaseRef TfNormFusion::DefinePattern() const { | |||||
| VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_}); | VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_}); | ||||
| auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode); | auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode); | ||||
| VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref}); | VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref}); | ||||
| @@ -128,13 +94,26 @@ const BaseRef LayerNormFusion::DefinePattern() const { | |||||
| return add2_ref; | return add2_ref; | ||||
| } | } | ||||
| CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon, | |||||
| int begin_norm_axis, int begin_params_axis) const { | |||||
| CNodePtr TfNormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||||
| const schema::PrimitiveType type, float epsilon, int begin_norm_axis, | |||||
| int begin_params_axis) const { | |||||
| MS_ASSERT(func_graph != nullptr); | MS_ASSERT(func_graph != nullptr); | ||||
| MS_ASSERT(equiv != nullptr); | MS_ASSERT(equiv != nullptr); | ||||
| auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>(); | |||||
| layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon); | |||||
| auto value_node = NewValueNode(layer_norm_primitive); | |||||
| auto norm_primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| norm_primitive->value.type = type; | |||||
| PrimitiveCPtr primitive = nullptr; | |||||
| if (type == schema::PrimitiveType_LayerNormFusion) { | |||||
| auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>(); | |||||
| layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon, true); | |||||
| primitive = layer_norm_primitive; | |||||
| } else if (type == schema::PrimitiveType_InstanceNorm) { | |||||
| auto instance_norm_primitive = std::make_shared<ops::InstanceNorm>(); | |||||
| instance_norm_primitive->Init(epsilon); | |||||
| primitive = instance_norm_primitive; | |||||
| } else { | |||||
| return nullptr; | |||||
| } | |||||
| auto value_node = NewValueNode(primitive); | |||||
| std::vector<AnfNodePtr> new_node_inputs = {value_node}; | std::vector<AnfNodePtr> new_node_inputs = {value_node}; | ||||
| auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]); | auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]); | ||||
| MS_ASSERT(input_node != nullptr); | MS_ASSERT(input_node != nullptr); | ||||
| @@ -149,10 +128,11 @@ CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, co | |||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, | |||||
| const std::vector<int> &params_shape, int *begin_norm_axis, | |||||
| int *begin_params_axis) const { | |||||
| bool TfNormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, | |||||
| const std::vector<int> &params_shape, schema::PrimitiveType *type, | |||||
| int *begin_norm_axis, int *begin_params_axis) const { | |||||
| MS_ASSERT(input_node != nullptr); | MS_ASSERT(input_node != nullptr); | ||||
| MS_ASSERT(type != nullptr); | |||||
| MS_ASSERT(begin_norm_axis != nullptr); | MS_ASSERT(begin_norm_axis != nullptr); | ||||
| MS_ASSERT(begin_params_axis != nullptr); | MS_ASSERT(begin_params_axis != nullptr); | ||||
| auto abstract = input_cnode->abstract(); | auto abstract = input_cnode->abstract(); | ||||
| @@ -170,30 +150,44 @@ bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector<int | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | auto shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | ||||
| if (mean_axes.back() + 1 != static_cast<int>(shape.size())) { | |||||
| MS_LOG(DEBUG) << "mean node is not reduce to last axis"; | |||||
| return false; | |||||
| } | |||||
| for (size_t i = 1; i < mean_axes.size(); ++i) { | for (size_t i = 1; i < mean_axes.size(); ++i) { | ||||
| if (mean_axes[i] != mean_axes[i - 1] + 1) { | if (mean_axes[i] != mean_axes[i - 1] + 1) { | ||||
| MS_LOG(DEBUG) << "mean axes is not continuous"; | MS_LOG(DEBUG) << "mean axes is not continuous"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| if (shape.size() == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) { | |||||
| if (params_shape.size() == 1 && params_shape.back() == shape.back()) { | |||||
| *type = schema::PrimitiveType_InstanceNorm; | |||||
| return true; | |||||
| } | |||||
| } | |||||
| if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast<int>(shape.size())) { | |||||
| MS_LOG(DEBUG) << "mean node is not reduce to last axis"; | |||||
| return false; | |||||
| } | |||||
| // there is no need to check params_shape | // there is no need to check params_shape | ||||
| *begin_norm_axis = mean_axes.front(); | *begin_norm_axis = mean_axes.front(); | ||||
| *begin_params_axis = static_cast<int>(shape.size()) - static_cast<int>(params_shape.size()); | |||||
| if (*begin_params_axis < 0) { | |||||
| MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse"; | |||||
| return false; | |||||
| if (*begin_norm_axis >= 0) { | |||||
| *begin_params_axis = static_cast<int>(shape.size()) - static_cast<int>(params_shape.size()); | |||||
| if (*begin_params_axis < 0) { | |||||
| MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse"; | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| *begin_params_axis = -static_cast<int>(params_shape.size()); | |||||
| } | } | ||||
| *type = schema::PrimitiveType_LayerNormFusion; | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, | |||||
| int *begin_params_axis) const { | |||||
| bool TfNormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, | |||||
| int *begin_norm_axis, int *begin_params_axis) const { | |||||
| MS_ASSERT(equiv != nullptr); | MS_ASSERT(equiv != nullptr); | ||||
| MS_ASSERT(epsilon != nullptr); | MS_ASSERT(epsilon != nullptr); | ||||
| MS_ASSERT(type != nullptr); | |||||
| MS_ASSERT(begin_norm_axis != nullptr); | MS_ASSERT(begin_norm_axis != nullptr); | ||||
| MS_ASSERT(begin_params_axis != nullptr); | MS_ASSERT(begin_params_axis != nullptr); | ||||
| // beta | // beta | ||||
| @@ -243,9 +237,6 @@ bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *b | |||||
| if (mean1_axes != mean2_axes) { | if (mean1_axes != mean2_axes) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (mean1_axes.size() != gamma_shape.size() || mean1_axes.size() != beta_shape.size()) { | |||||
| return false; | |||||
| } | |||||
| if (gamma_shape != beta_shape) { | if (gamma_shape != beta_shape) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -254,14 +245,14 @@ bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *b | |||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (!GetAxis(input_cnode, mean1_axes, gamma_shape, begin_norm_axis, begin_params_axis)) { | |||||
| if (!GetNormTypeAndAxis(input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis)) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| const AnfNodePtr TfNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| MS_ASSERT(func_graph != nullptr); | MS_ASSERT(func_graph != nullptr); | ||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| MS_ASSERT(equiv != nullptr); | MS_ASSERT(equiv != nullptr); | ||||
| @@ -273,14 +264,24 @@ const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const | |||||
| float epsilon = 0.0f; | float epsilon = 0.0f; | ||||
| int begin_norm_axis = 0; | int begin_norm_axis = 0; | ||||
| int begin_params_axis = 0; | int begin_params_axis = 0; | ||||
| if (!CheckPattern(equiv, &epsilon, &begin_norm_axis, &begin_params_axis)) { | |||||
| schema::PrimitiveType type = schema::PrimitiveType_NONE; | |||||
| if (!CheckPattern(equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto norm_cnode = CreateNormNode(func_graph, equiv, type, epsilon, begin_norm_axis, begin_params_axis); | |||||
| if (norm_cnode == nullptr) { | |||||
| MS_LOG(DEBUG) << "create norm cnode failed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, epsilon, begin_norm_axis, begin_params_axis); | |||||
| layer_norm_cnode->set_abstract(add2_cnode->abstract()->Clone()); | |||||
| layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); | |||||
| MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success"; | |||||
| return layer_norm_cnode; | |||||
| norm_cnode->set_abstract(add2_cnode->abstract()->Clone()); | |||||
| if (type == schema::PrimitiveType_LayerNormFusion) { | |||||
| norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); | |||||
| MS_LOG(INFO) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success"; | |||||
| } else if (type == schema::PrimitiveType_InstanceNorm) { | |||||
| norm_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope()); | |||||
| MS_LOG(INFO) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success"; | |||||
| } | |||||
| return norm_cnode; | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,107 @@ | |||||
| /** | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "utils/utils.h" | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| /// Pattern pass that fuses layer_norm and instance_norm subgraphs into one operator. | |||||
| /// A matched subgraph is replaced by either a LayerNormFusion or an InstanceNorm node, | |||||
| /// depending on the primitive type reported by GetNormTypeAndAxis. | |||||
| class TfNormFusion : public PatternProcessPass { | |||||
| public: | |||||
| /// Creates the pass and allocates a fresh Var for every pattern placeholder. | |||||
| /// @param name        pass name forwarded to PatternProcessPass. | |||||
| /// @param multigraph  forwarded unchanged to PatternProcessPass. | |||||
| explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true) | |||||
| : PatternProcessPass(name, multigraph) { | |||||
| input_ = std::make_shared<Var>(); | |||||
| mean1_ = std::make_shared<Var>(); | |||||
| mean1_axes_ = std::make_shared<Var>(); | |||||
| mean2_ = std::make_shared<Var>(); | |||||
| mean2_axes_ = std::make_shared<Var>(); | |||||
| gamma_ = std::make_shared<Var>(); | |||||
| beta_ = std::make_shared<Var>(); | |||||
| epsilon_ = std::make_shared<Var>(); | |||||
| } | |||||
| ~TfNormFusion() override = default; | |||||
| /// Returns the pattern this pass matches (defined in the implementation file). | |||||
| const BaseRef DefinePattern() const override; | |||||
| /// Validates the match via CheckPattern and builds the replacement via CreateNormNode. | |||||
| /// @return the fused node, or nullptr when the check or the node creation fails. | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| /// Decides which norm primitive the match corresponds to and derives its axes. | |||||
| /// @param[out] type               set to PrimitiveType_InstanceNorm or PrimitiveType_LayerNormFusion. | |||||
| /// @param[out] begin_norm_axis    written for the layer-norm case. | |||||
| /// @param[out] begin_params_axis  written for the layer-norm case. | |||||
| /// @return false when the subgraph must not be fused. | |||||
| bool GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, | |||||
| const std::vector<int> &params_shape, schema::PrimitiveType *type, int *begin_norm_axis, | |||||
| int *begin_params_axis) const; | |||||
| /// Checks the matched nodes and fills the out-parameters; ends by delegating to GetNormTypeAndAxis. | |||||
| /// @return true only when every check passes. | |||||
| bool CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, int *begin_norm_axis, | |||||
| int *begin_params_axis) const; | |||||
| /// Builds the fused node for `type`. | |||||
| /// @return nullptr when `type` is neither LayerNormFusion nor InstanceNorm. | |||||
| CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type, | |||||
| float epsilon, int begin_norm_axis, int begin_params_axis) const; | |||||
| protected: | |||||
| // Pattern placeholders; each one is replaced by a fresh Var in the constructor. | |||||
| VarPtr input_ = nullptr; | |||||
| VarPtr mean1_ = nullptr; | |||||
| VarPtr mean1_axes_ = nullptr; | |||||
| VarPtr mean2_ = nullptr; | |||||
| VarPtr mean2_axes_ = nullptr; | |||||
| VarPtr gamma_ = nullptr; | |||||
| VarPtr beta_ = nullptr; | |||||
| VarPtr epsilon_ = nullptr; | |||||
| }; | |||||
| /// Returns true only if `n` holds an AnfNodePtr that CheckPrimitiveType accepts as prim::kPrimAddFusion. | |||||
| inline bool IsAddNode(const BaseRef &n) { | |||||
| if (!utils::isa<AnfNodePtr>(n)) { | |||||
| return false; | |||||
| } | |||||
| const auto anf_node = utils::cast<AnfNodePtr>(n); | |||||
| return CheckPrimitiveType(anf_node, prim::kPrimAddFusion); | |||||
| } | |||||
| /// Returns true only if `n` holds an AnfNodePtr that CheckPrimitiveType accepts as prim::kPrimSquaredDifference. | |||||
| inline bool IsSquaredDifferenceNode(const BaseRef &n) { | |||||
| if (!utils::isa<AnfNodePtr>(n)) { | |||||
| return false; | |||||
| } | |||||
| const auto anf_node = utils::cast<AnfNodePtr>(n); | |||||
| return CheckPrimitiveType(anf_node, prim::kPrimSquaredDifference); | |||||
| } | |||||
| /// Returns true only if `n` holds an AnfNodePtr that CheckPrimitiveType accepts as prim::kPrimRsqrt. | |||||
| inline bool IsRsqrtNode(const BaseRef &n) { | |||||
| if (!utils::isa<AnfNodePtr>(n)) { | |||||
| return false; | |||||
| } | |||||
| const auto anf_node = utils::cast<AnfNodePtr>(n); | |||||
| return CheckPrimitiveType(anf_node, prim::kPrimRsqrt); | |||||
| } | |||||
| /// Returns true only if `n` holds an AnfNodePtr that CheckPrimitiveType accepts as prim::kPrimMulFusion. | |||||
| inline bool IsMulNode(const BaseRef &n) { | |||||
| if (!utils::isa<AnfNodePtr>(n)) { | |||||
| return false; | |||||
| } | |||||
| const auto anf_node = utils::cast<AnfNodePtr>(n); | |||||
| return CheckPrimitiveType(anf_node, prim::kPrimMulFusion); | |||||
| } | |||||
| /// Returns true only if `n` holds an AnfNodePtr that CheckPrimitiveType accepts as prim::kPrimSubFusion. | |||||
| inline bool IsSubNode(const BaseRef &n) { | |||||
| if (!utils::isa<AnfNodePtr>(n)) { | |||||
| return false; | |||||
| } | |||||
| const auto anf_node = utils::cast<AnfNodePtr>(n); | |||||
| return CheckPrimitiveType(anf_node, prim::kPrimSubFusion); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||||