From 70af1d1615f27d3ce80f111e691e27930fc7a6f4 Mon Sep 17 00:00:00 2001 From: wangzhe Date: Mon, 22 Feb 2021 14:25:28 +0800 Subject: [PATCH] instance norm fusion & support onnx layer norm --- mindspore/lite/nnacl/base/cast_base.h | 15 ++ mindspore/lite/nnacl/infer/cast_infer.c | 3 +- .../src/runtime/kernel/arm/fp16/cast_fp16.cc | 12 ++ .../src/runtime/kernel/arm/fp32/cast_fp32.cc | 59 +++++--- .../src/runtime/kernel/arm/fp32/cast_fp32.h | 1 + mindspore/lite/test/CMakeLists.txt | 3 +- mindspore/lite/test/models_npu.cfg | 2 +- mindspore/lite/tools/common/node_util.cc | 4 + mindspore/lite/tools/common/node_util.h | 2 + mindspore/lite/tools/converter/CMakeLists.txt | 3 +- .../lite/tools/converter/anf_transform.cc | 6 +- .../graph/format_trans_pass.cc | 12 +- .../optimizer/fusion/layer_norm_fusion.h | 65 --------- .../fusion/onnx_layer_norm_fusion.cc | 37 +++++ .../optimizer/fusion/onnx_layer_norm_fusion.h | 60 ++++++++ ...layer_norm_fusion.cc => tf_norm_fusion.cc} | 135 +++++++++--------- .../tools/optimizer/fusion/tf_norm_fusion.h | 107 ++++++++++++++ 17 files changed, 366 insertions(+), 160 deletions(-) delete mode 100644 mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h create mode 100644 mindspore/lite/tools/optimizer/fusion/onnx_layer_norm_fusion.cc create mode 100644 mindspore/lite/tools/optimizer/fusion/onnx_layer_norm_fusion.h rename mindspore/lite/tools/optimizer/fusion/{layer_norm_fusion.cc => tf_norm_fusion.cc} (69%) create mode 100644 mindspore/lite/tools/optimizer/fusion/tf_norm_fusion.h diff --git a/mindspore/lite/nnacl/base/cast_base.h b/mindspore/lite/nnacl/base/cast_base.h index a769cfd7d9..757007b0b3 100644 --- a/mindspore/lite/nnacl/base/cast_base.h +++ b/mindspore/lite/nnacl/base/cast_base.h @@ -41,6 +41,20 @@ inline void Int32ToFloat32(const int32_t *input, float *output, int number) { } } +inline void Int64ToFloat32(const int64_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +#ifdef ENABLE_FP16 +inline void Int64ToFp16(const int64_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} +#endif + inline void Fp16ToFloat32(const uint16_t *input, float *output, int number) { for (int i = 0; i < number; ++i) { output[i] = ShortToFloat32(input[i]); @@ -82,6 +96,7 @@ inline void BoolToInt32(const bool *input, int32_t *output, int number) { output[i] = (int32_t)input[i]; } } + #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/infer/cast_infer.c b/mindspore/lite/nnacl/infer/cast_infer.c index 2474d5cb0a..beb01ae077 100644 --- a/mindspore/lite/nnacl/infer/cast_infer.c +++ b/mindspore/lite/nnacl/infer/cast_infer.c @@ -36,7 +36,8 @@ int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o } if (input->data_type_ != kNumberTypeBool && input->data_type_ != kNumberTypeUInt8 && input->data_type_ != kNumberTypeInt8 && input->data_type_ != kNumberTypeInt32 && - input->data_type_ != kNumberTypeFloat32 && input->data_type_ != kNumberTypeFloat16) { + input->data_type_ != kNumberTypeInt64 && input->data_type_ != kNumberTypeFloat32 && + input->data_type_ != kNumberTypeFloat16) { return NNACL_INPUT_TENSOR_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc index b24f3abbe2..bac5d3391c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc @@ -121,6 +121,17 @@ int CastFp16CPUKernel::DoCast(int thread_id) { MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; return RET_ERROR; } + } else if (input_data_type == kNumberTypeInt64) { + switch (output_data_type) { + case kNumberTypeFloat16: + Int64ToFloat32(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; + default: + MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; + return RET_ERROR; + } + } else { MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; return RET_ERROR; @@ -136,4 +147,5 @@ int CastFp16CPUKernel::Run() { } REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc index e2c35e8b6f..9fae4807a9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc @@ -53,6 +53,37 @@ int CastCPUKernel::ReSize() { return RET_OK; } +int CastCPUKernel::CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num) { + auto input_data_type = input->data_type(); + auto output_data = output->data_c(); + switch (input_data_type) { + case kNumberTypeBool: + BoolToFloat32(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; + case kNumberTypeUInt8: + Uint8ToFloat32(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFloat32(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; + case kNumberTypeFloat16: + Fp16ToFloat32(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; + case kNumberTypeInt64: + Int64ToFloat32(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; + default: + MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; + return RET_ERROR; + } + return RET_OK; +} + int CastCPUKernel::DoCast(int thread_id) { auto input = in_tensors_.at(0); int data_num = MSMIN(stride_, data_num_ - thread_id * stride_); @@ -91,32 +122,17 @@ int CastCPUKernel::DoCast(int thread_id) { } else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { BoolToInt32(reinterpret_cast(input->data_c()) + offset, reinterpret_cast(output_data) + offset, data_num); +#ifdef ENABLE_FP16 + } else if (input_data_type == kNumberTypeInt64 && output_data_type == kNumberTypeFloat16) { + Int64ToFp16(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); +#endif } else { MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; return RET_ERROR; } } else { - switch (input_data_type) { - case kNumberTypeBool: - BoolToFloat32(reinterpret_cast(input->MutableData()) + offset, - reinterpret_cast(output_data) + offset, data_num); - break; - case kNumberTypeUInt8: - Uint8ToFloat32(reinterpret_cast(input->MutableData()) + offset, - reinterpret_cast(output_data) + offset, data_num); - break; - case kNumberTypeInt32: - Int32ToFloat32(reinterpret_cast(input->MutableData()) + offset, - reinterpret_cast(output_data) + offset, data_num); - break; - case kNumberTypeFloat16: - Fp16ToFloat32(reinterpret_cast(input->MutableData()) + offset, - reinterpret_cast(output_data) + offset, data_num); - break; - default: - MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; - return RET_ERROR; - } + return CastToFp32(input, output, offset, data_num); } return RET_OK; } @@ -132,6 +148,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Cast, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Cast, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Cast, LiteKernelCreator) #ifndef ENABLE_ARM REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.h index cd976f044b..1e3a63af44 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.h @@ -38,6 +38,7 @@ class CastCPUKernel : public LiteKernel { int DoCast(int thread_id); private: + int CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num); int stride_; int data_num_; }; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 5a1e996312..cdaff50788 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -233,7 +233,8 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc - ${LITE_DIR}/tools/optimizer/fusion/layer_norm_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/tf_norm_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/onnx_layer_norm_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc diff --git a/mindspore/lite/test/models_npu.cfg b/mindspore/lite/test/models_npu.cfg index c19e5c6833..a7f1738079 100644 --- a/mindspore/lite/test/models_npu.cfg +++ b/mindspore/lite/test/models_npu.cfg @@ -38,7 +38,7 @@ ml_video_edit_img_segment 1 ml_video_edit_video_segment_gauss_adaptis_part1 2 ml_video_edit_generate_filter.pb 1 ml_video_edit_img_segment_adaptise.pb 0.5 2 -ml_video_edit_video_segment_gauss_adaptis_part2.pb 3 2 +ml_video_edit_video_segment_gauss_adaptis_part2.pb 10 2 ml_video_edit_person_divison_pic 0.5 ml_video_edit_person_divison_video 13 2 ml_video_edit_imitate_filter.onnx 230 diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index d7f79b6952..2186d24bca 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -51,6 +51,8 @@ static const std::vector nhwcOpList = {schema::PrimitiveT schema::PrimitiveType_SpaceToBatch, schema::PrimitiveType_SpaceToBatchND}; +static const std::vector nchwOpList = {schema::PrimitiveType_InstanceNorm}; + static const std::vector nhwcOpAllInputList = { schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion, @@ -153,6 +155,8 @@ std::vector Getfp32FullOpList() { return fp32FullOpList; std::vector GetNhwcOpList() { return nhwcOpList; } +std::vector GetNchwOpList() { return nchwOpList; } + std::unordered_map> GetExtNhwcIndexes() { return extNhwcInsertIndex; } std::vector GetNhwcAllInputOpList() { return nhwcOpAllInputList; } diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index c630a8e305..1e4d9f4980 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -60,6 +60,8 @@ std::vector GetInsertOpList(); std::vector GetNhwcOpList(); +std::vector GetNchwOpList(); + std::vector GetNhwcAllInputOpList(); std::unordered_map> GetExtNhwcIndexes(); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index cf20f723de..7f9a28fccb 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -44,7 +44,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/conv_tuplegetitem_fusion.cc ../optimizer/fusion/constant_folding_fusion.cc ../optimizer/fusion/quant_dtype_cast_fusion.cc - ../optimizer/fusion/layer_norm_fusion.cc + ../optimizer/fusion/tf_norm_fusion.cc + ../optimizer/fusion/onnx_layer_norm_fusion.cc ../optimizer/fusion/batchmatmul_fusion.cc ../optimizer/fusion/sigmoid_mul_fusion.cc ../optimizer/fusion/conv_conv_fusion.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 5b65c9c939..6da6f8496d 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -27,7 +27,8 @@ #include "tools/optimizer/fusion/conv_bn_fusion.h" #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" -#include "tools/optimizer/fusion/layer_norm_fusion.h" +#include "tools/optimizer/fusion/tf_norm_fusion.h" +#include "tools/optimizer/fusion/onnx_layer_norm_fusion.h" #include "tools/optimizer/fusion/batchmatmul_fusion.h" #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" #include "tools/optimizer/fusion/conv_conv_fusion.h" @@ -77,7 +78,8 @@ int AnfTransform::AddFusionPass(const std::shared_ptr &opti auto conv_scale_pass = std::make_shared(); conv_scale_pass->SetFmkType(config->fmk); fusion_pm->AddPass(conv_scale_pass); - fusion_pm->AddPass(std::make_shared()); + fusion_pm->AddPass(std::make_shared()); + fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc index 58d9723cc8..f8ea5c98af 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -48,7 +48,12 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType, FormatTransNodeType *afterNodeType) { if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc - return RET_NO_CHANGE; + if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) { + return RET_NO_CHANGE; + } + *beforeNodeType = kNHWC2NCHW; + *afterNodeType = kNCHW2NHWC; + return RET_OK; } else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS || fmk_type_ == converter::FmkType_ONNX) { if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { @@ -63,6 +68,11 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT *afterNodeType = kNHWC2NCHW; return RET_OK; } + if (IsContain(GetNchwOpList(), GetCNodeTType(node))) { + *beforeNodeType = kNHWC2NCHW; + *afterNodeType = kNCHW2NHWC; + return RET_OK; + } return RET_NO_CHANGE; } MS_LOG(ERROR) << "Unsupported fmk: " << fmk_type_; diff --git a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h deleted file mode 100644 index 75d6bb44d6..0000000000 --- a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ -#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ - -#include -#include -#include -#include "backend/optimizer/common/optimizer.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { - -class LayerNormFusion : public PatternProcessPass { - public: - explicit LayerNormFusion(const std::string &name = "layer_norm_fusion", bool multigraph = true) - : PatternProcessPass(name, multigraph) { - input_ = std::make_shared(); - mean1_ = std::make_shared(); - mean1_axes_ = std::make_shared(); - mean2_ = std::make_shared(); - mean2_axes_ = std::make_shared(); - gamma_ = std::make_shared(); - beta_ = std::make_shared(); - epsilon_ = std::make_shared(); - } - - ~LayerNormFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - bool GetAxis(const CNodePtr &input_cnode, const std::vector &mean_axes, const std::vector ¶ms_shape, - int *begin_norm_axis, int *begin_params_axis) const; - bool CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, int *begin_params_axis) const; - CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon, - int begin_norm_axis, int begin_params_axis) const; - VarPtr input_ = nullptr; - VarPtr mean1_ = nullptr; - VarPtr mean1_axes_ = nullptr; - VarPtr mean2_ = nullptr; - VarPtr mean2_axes_ = nullptr; - VarPtr gamma_ = nullptr; - VarPtr beta_ = nullptr; - VarPtr epsilon_ = nullptr; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/onnx_layer_norm_fusion.cc b/mindspore/lite/tools/optimizer/fusion/onnx_layer_norm_fusion.cc new file mode 100644 index 0000000000..85eae92b43 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/onnx_layer_norm_fusion.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/fusion/onnx_layer_norm_fusion.h" +#include +#include "ops/rsqrt.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore { +namespace opt { +const BaseRef OnnxLayerNormFusion::DefinePattern() const { + VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_}); + VectorRef sub1_ref = VectorRef({std::make_shared(IsSubNode), input_, mean1_ref}); + VectorRef sub2_ref = VectorRef({std::make_shared(IsSubNode), input_, mean1_ref}); + VectorRef pow_ref = VectorRef({std::make_shared(IsPowNode), sub2_ref, std::make_shared()}); + VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_}); + VectorRef add1_ref = VectorRef({std::make_shared(IsAddNode), mean2_ref, epsilon_}); + VectorRef sqrt_ref = VectorRef({std::make_shared(IsSqrtNode), add1_ref}); + VectorRef div_ref = VectorRef({std::make_shared(IsDivNode), sub1_ref, sqrt_ref}); + VectorRef mul_ref = VectorRef({std::make_shared(IsMulNode), gamma_, div_ref}); + VectorRef add2_ref = VectorRef({std::make_shared(IsAddNode), mul_ref, beta_}); + return add2_ref; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fusion/onnx_layer_norm_fusion.h b/mindspore/lite/tools/optimizer/fusion/onnx_layer_norm_fusion.h new file mode 100644 index 0000000000..bab1c11bc7 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/onnx_layer_norm_fusion.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ + +#include +#include +#include +#include "tools/optimizer/fusion/tf_norm_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { + +class OnnxLayerNormFusion : public TfNormFusion { + public: + explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true) + : TfNormFusion(name, multigraph) {} + + ~OnnxLayerNormFusion() override = default; + const BaseRef DefinePattern() const override; +}; + +inline bool IsPowNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimPowFusion); + } + return false; +} +inline bool IsSqrtNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimSqrt); + } + return false; +} +inline bool IsDivNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimDiv) || + CheckPrimitiveType(utils::cast(n), std::make_shared("DivFusion")); + } + return false; +} +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_norm_fusion.cc similarity index 69% rename from mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc rename to mindspore/lite/tools/optimizer/fusion/tf_norm_fusion.cc index 1d744066a2..446de64873 100644 --- a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/tf_norm_fusion.cc @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "tools/optimizer/fusion/layer_norm_fusion.h" +#include "tools/optimizer/fusion/tf_norm_fusion.h" #include #include "ops/fusion/layer_norm_fusion.h" #include "ops/fusion/reduce_fusion.h" #include "ops/rsqrt.h" +#include "mindspore/core/ops/instance_norm.h" #include "src/param_value_lite.h" #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" @@ -26,41 +27,6 @@ namespace mindspore { namespace opt { namespace { -bool IsAddNode(const BaseRef &n) { - if (utils::isa(n)) { - return CheckPrimitiveType(utils::cast(n), prim::kPrimAddFusion); - } - return false; -} - -bool IsSquaredDifferenceNode(const BaseRef &n) { - if (utils::isa(n)) { - return CheckPrimitiveType(utils::cast(n), prim::kPrimSquaredDifference); - } - return false; -} - -bool IsRsqrtNode(const BaseRef &n) { - if (utils::isa(n)) { - return CheckPrimitiveType(utils::cast(n), prim::kPrimRsqrt); - } - return false; -} - -bool IsMulNode(const BaseRef &n) { - if (utils::isa(n)) { - return CheckPrimitiveType(utils::cast(n), prim::kPrimMulFusion); - } - return false; -} - -bool IsSubNode(const BaseRef &n) { - if (utils::isa(n)) { - return CheckPrimitiveType(utils::cast(n), prim::kPrimSubFusion); - } - return false; -} - lite::STATUS GetReduceAxes(const BaseRef &n, std::vector *axes) { MS_ASSERT(node != nullptr); if (utils::isa(n)) { @@ -106,7 +72,7 @@ bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr } } // namespace -const BaseRef LayerNormFusion::DefinePattern() const { +const BaseRef TfNormFusion::DefinePattern() const { VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_}); auto squared_diffference1 = std::make_shared(IsSquaredDifferenceNode); VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref}); @@ -128,13 +94,26 @@ const BaseRef LayerNormFusion::DefinePattern() const { return add2_ref; } -CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon, - int begin_norm_axis, int begin_params_axis) const { +CNodePtr TfNormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, + const schema::PrimitiveType type, float epsilon, int begin_norm_axis, + int begin_params_axis) const { MS_ASSERT(func_graph != nullptr); MS_ASSERT(equiv != nullptr); - auto layer_norm_primitive = std::make_shared(); - layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon); - auto value_node = NewValueNode(layer_norm_primitive); + auto norm_primitive = std::make_unique(); + norm_primitive->value.type = type; + PrimitiveCPtr primitive = nullptr; + if (type == schema::PrimitiveType_LayerNormFusion) { + auto layer_norm_primitive = std::make_shared(); + layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon, true); + primitive = layer_norm_primitive; + } else if (type == schema::PrimitiveType_InstanceNorm) { + auto instance_norm_primitive = std::make_shared(); + instance_norm_primitive->Init(epsilon); + primitive = instance_norm_primitive; + } else { + return nullptr; + } + auto value_node = NewValueNode(primitive); std::vector new_node_inputs = {value_node}; auto input_node = utils::cast((*equiv)[input_]); MS_ASSERT(input_node != nullptr); @@ -149,10 +128,11 @@ CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, co return new_node; } -bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector &mean_axes, - const std::vector ¶ms_shape, int *begin_norm_axis, - int *begin_params_axis) const { +bool TfNormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector &mean_axes, + const std::vector ¶ms_shape, schema::PrimitiveType *type, + int *begin_norm_axis, int *begin_params_axis) const { MS_ASSERT(input_node != nullptr); + MS_ASSERT(type != nullptr); MS_ASSERT(begin_norm_axis != nullptr); MS_ASSERT(begin_params_axis != nullptr); auto abstract = input_cnode->abstract(); @@ -170,30 +150,44 @@ bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector(abstract_tensor->BuildShape())->shape(); - if (mean_axes.back() + 1 != static_cast(shape.size())) { - MS_LOG(DEBUG) << "mean node is not reduce to last axis"; - return false; - } for (size_t i = 1; i < mean_axes.size(); ++i) { if (mean_axes[i] != mean_axes[i - 1] + 1) { MS_LOG(DEBUG) << "mean axes is not continuous"; return false; } } + if (shape.size() == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) { + if (params_shape.size() == 1 && params_shape.back() == shape.back()) { + *type = schema::PrimitiveType_InstanceNorm; + return true; + } + } + if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast(shape.size())) { + MS_LOG(DEBUG) << "mean node is not reduce to last axis"; + return false; + } + // there is no need to check params_shape *begin_norm_axis = mean_axes.front(); - *begin_params_axis = static_cast(shape.size()) - static_cast(params_shape.size()); - if (*begin_params_axis < 0) { - MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse"; - return false; + if (*begin_norm_axis >= 0) { + *begin_params_axis = static_cast(shape.size()) - static_cast(params_shape.size()); + if (*begin_params_axis < 0) { + MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse"; + return false; + } + } else { + *begin_params_axis = -static_cast(params_shape.size()); } + + *type = schema::PrimitiveType_LayerNormFusion; return true; } -bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, - int *begin_params_axis) const { +bool TfNormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, + int *begin_norm_axis, int *begin_params_axis) const { MS_ASSERT(equiv != nullptr); MS_ASSERT(epsilon != nullptr); + MS_ASSERT(type != nullptr); MS_ASSERT(begin_norm_axis != nullptr); MS_ASSERT(begin_params_axis != nullptr); // beta @@ -243,9 +237,6 @@ bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *b if (mean1_axes != mean2_axes) { return false; } - if (mean1_axes.size() != gamma_shape.size() || mean1_axes.size() != beta_shape.size()) { - return false; - } if (gamma_shape != beta_shape) { return false; } @@ -254,14 +245,14 @@ bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *b } else { return false; } - if (!GetAxis(input_cnode, mean1_axes, gamma_shape, begin_norm_axis, begin_params_axis)) { + if (!GetNormTypeAndAxis(input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis)) { return false; } return true; } -const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { +const AnfNodePtr TfNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { MS_ASSERT(func_graph != nullptr); MS_ASSERT(node != nullptr); MS_ASSERT(equiv != nullptr); @@ -273,14 +264,24 @@ const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const float epsilon = 0.0f; int begin_norm_axis = 0; int begin_params_axis = 0; - if (!CheckPattern(equiv, &epsilon, &begin_norm_axis, &begin_params_axis)) { + schema::PrimitiveType type = schema::PrimitiveType_NONE; + if (!CheckPattern(equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) { + return nullptr; + } + auto norm_cnode = CreateNormNode(func_graph, equiv, type, epsilon, begin_norm_axis, begin_params_axis); + if (norm_cnode == nullptr) { + MS_LOG(DEBUG) << "create norm cnode failed"; return nullptr; } - auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, epsilon, begin_norm_axis, begin_params_axis); - layer_norm_cnode->set_abstract(add2_cnode->abstract()->Clone()); - layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); - MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success"; - return layer_norm_cnode; + norm_cnode->set_abstract(add2_cnode->abstract()->Clone()); + if (type == schema::PrimitiveType_LayerNormFusion) { + norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); + MS_LOG(INFO) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success"; + } else if (type == schema::PrimitiveType_InstanceNorm) { + norm_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope()); + MS_LOG(INFO) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success"; + } + return norm_cnode; } } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fusion/tf_norm_fusion.h b/mindspore/lite/tools/optimizer/fusion/tf_norm_fusion.h new file mode 100644 index 0000000000..64369b89a4 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/tf_norm_fusion.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ + +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore { +namespace opt { + +/// fuse layer_norm, instance_norm into one operator +class TfNormFusion : public PatternProcessPass { + public: + explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true) + : PatternProcessPass(name, multigraph) { + input_ = std::make_shared(); + mean1_ = std::make_shared(); + mean1_axes_ = std::make_shared(); + mean2_ = std::make_shared(); + mean2_axes_ = std::make_shared(); + gamma_ = std::make_shared(); + beta_ = std::make_shared(); + epsilon_ = std::make_shared(); + } + + ~TfNormFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + bool GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector &mean_axes, + const std::vector ¶ms_shape, schema::PrimitiveType *type, int *begin_norm_axis, + int *begin_params_axis) const; + bool CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, int *begin_norm_axis, + int *begin_params_axis) const; + CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type, + float epsilon, int begin_norm_axis, int begin_params_axis) const; + + protected: + VarPtr input_ = nullptr; + VarPtr mean1_ = nullptr; + VarPtr mean1_axes_ = nullptr; + VarPtr mean2_ = nullptr; + VarPtr mean2_axes_ = nullptr; + VarPtr gamma_ = nullptr; + VarPtr beta_ = nullptr; + VarPtr epsilon_ = nullptr; +}; + +inline bool IsAddNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimAddFusion); + } + return false; +} + +inline bool IsSquaredDifferenceNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimSquaredDifference); + } + return false; +} + +inline bool IsRsqrtNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimRsqrt); + } + return false; +} + +inline bool IsMulNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimMulFusion); + } + return false; +} + +inline bool IsSubNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimSubFusion); + } + return false; +} +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_