From: @wangzhe128 Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiangtags/v1.2.0-rc1
| @@ -41,6 +41,20 @@ inline void Int32ToFloat32(const int32_t *input, float *output, int number) { | |||
| } | |||
| } | |||
| inline void Int64ToFloat32(const int64_t *input, float *output, int number) { | |||
| for (int i = 0; i < number; ++i) { | |||
| output[i] = (float)input[i]; | |||
| } | |||
| } | |||
| #ifdef ENABLE_FP16 | |||
| inline void Int64ToFp16(const int64_t *input, float16_t *output, int number) { | |||
| for (int i = 0; i < number; ++i) { | |||
| output[i] = (float16_t)input[i]; | |||
| } | |||
| } | |||
| #endif | |||
| inline void Fp16ToFloat32(const uint16_t *input, float *output, int number) { | |||
| for (int i = 0; i < number; ++i) { | |||
| output[i] = ShortToFloat32(input[i]); | |||
| @@ -82,6 +96,7 @@ inline void BoolToInt32(const bool *input, int32_t *output, int number) { | |||
| output[i] = (int32_t)input[i]; | |||
| } | |||
| } | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -36,7 +36,8 @@ int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o | |||
| } | |||
| if (input->data_type_ != kNumberTypeBool && input->data_type_ != kNumberTypeUInt8 && | |||
| input->data_type_ != kNumberTypeInt8 && input->data_type_ != kNumberTypeInt32 && | |||
| input->data_type_ != kNumberTypeFloat32 && input->data_type_ != kNumberTypeFloat16) { | |||
| input->data_type_ != kNumberTypeInt64 && input->data_type_ != kNumberTypeFloat32 && | |||
| input->data_type_ != kNumberTypeFloat16) { | |||
| return NNACL_INPUT_TENSOR_ERROR; | |||
| } | |||
| @@ -121,6 +121,17 @@ int CastFp16CPUKernel::DoCast(int thread_id) { | |||
| MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (input_data_type == kNumberTypeInt64) { | |||
| switch (output_data_type) { | |||
| case kNumberTypeFloat16: | |||
| Int64ToFloat32(reinterpret_cast<int64_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | |||
| return RET_ERROR; | |||
| @@ -136,4 +147,5 @@ int CastFp16CPUKernel::Run() { | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -53,6 +53,37 @@ int CastCPUKernel::ReSize() { | |||
| return RET_OK; | |||
| } | |||
| int CastCPUKernel::CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num) { | |||
| auto input_data_type = input->data_type(); | |||
| auto output_data = output->data_c(); | |||
| switch (input_data_type) { | |||
| case kNumberTypeBool: | |||
| BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| case kNumberTypeUInt8: | |||
| Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| case kNumberTypeInt32: | |||
| Int32ToFloat32(reinterpret_cast<int32_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| case kNumberTypeFloat16: | |||
| Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| case kNumberTypeInt64: | |||
| Int64ToFloat32(reinterpret_cast<int64_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int CastCPUKernel::DoCast(int thread_id) { | |||
| auto input = in_tensors_.at(0); | |||
| int data_num = MSMIN(stride_, data_num_ - thread_id * stride_); | |||
| @@ -91,32 +122,17 @@ int CastCPUKernel::DoCast(int thread_id) { | |||
| } else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { | |||
| BoolToInt32(reinterpret_cast<bool *>(input->data_c()) + offset, reinterpret_cast<int32_t *>(output_data) + offset, | |||
| data_num); | |||
| #ifdef ENABLE_FP16 | |||
| } else if (input_data_type == kNumberTypeInt64 && output_data_type == kNumberTypeFloat16) { | |||
| Int64ToFp16(reinterpret_cast<int64_t *>(input->data_c()) + offset, | |||
| reinterpret_cast<float16_t *>(output_data) + offset, data_num); | |||
| #endif | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| switch (input_data_type) { | |||
| case kNumberTypeBool: | |||
| BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| case kNumberTypeUInt8: | |||
| Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| case kNumberTypeInt32: | |||
| Int32ToFloat32(reinterpret_cast<int32_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| case kNumberTypeFloat16: | |||
| Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | |||
| return RET_ERROR; | |||
| } | |||
| return CastToFp32(input, output, offset, data_num); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -132,6 +148,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, LiteKernelCreator<CastC | |||
| REG_KERNEL(kCPU, kNumberTypeUInt8, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | |||
| #ifndef ENABLE_ARM | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>) | |||
| @@ -38,6 +38,7 @@ class CastCPUKernel : public LiteKernel { | |||
| int DoCast(int thread_id); | |||
| private: | |||
| int CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num); | |||
| int stride_; | |||
| int data_num_; | |||
| }; | |||
| @@ -233,7 +233,8 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/layer_norm_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/tf_norm_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/onnx_layer_norm_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc | |||
| @@ -38,7 +38,7 @@ ml_video_edit_img_segment 1 | |||
| ml_video_edit_video_segment_gauss_adaptis_part1 2 | |||
| ml_video_edit_generate_filter.pb 1 | |||
| ml_video_edit_img_segment_adaptise.pb 0.5 2 | |||
| ml_video_edit_video_segment_gauss_adaptis_part2.pb 3 2 | |||
| ml_video_edit_video_segment_gauss_adaptis_part2.pb 10 2 | |||
| ml_video_edit_person_divison_pic 0.5 | |||
| ml_video_edit_person_divison_video 13 2 | |||
| ml_video_edit_imitate_filter.onnx 230 | |||
| @@ -51,6 +51,8 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveT | |||
| schema::PrimitiveType_SpaceToBatch, | |||
| schema::PrimitiveType_SpaceToBatchND}; | |||
| static const std::vector<schema::PrimitiveType> nchwOpList = {schema::PrimitiveType_InstanceNorm}; | |||
| static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | |||
| schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, | |||
| schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion, | |||
| @@ -153,6 +155,8 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; | |||
| std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; } | |||
| std::vector<schema::PrimitiveType> GetNchwOpList() { return nchwOpList; } | |||
| std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; } | |||
| std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; } | |||
| @@ -60,6 +60,8 @@ std::vector<schema::PrimitiveType> GetInsertOpList(); | |||
| std::vector<schema::PrimitiveType> GetNhwcOpList(); | |||
| std::vector<schema::PrimitiveType> GetNchwOpList(); | |||
| std::vector<schema::PrimitiveType> GetNhwcAllInputOpList(); | |||
| std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes(); | |||
| @@ -44,7 +44,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/fusion/conv_tuplegetitem_fusion.cc | |||
| ../optimizer/fusion/constant_folding_fusion.cc | |||
| ../optimizer/fusion/quant_dtype_cast_fusion.cc | |||
| ../optimizer/fusion/layer_norm_fusion.cc | |||
| ../optimizer/fusion/tf_norm_fusion.cc | |||
| ../optimizer/fusion/onnx_layer_norm_fusion.cc | |||
| ../optimizer/fusion/batchmatmul_fusion.cc | |||
| ../optimizer/fusion/sigmoid_mul_fusion.cc | |||
| ../optimizer/fusion/conv_conv_fusion.cc | |||
| @@ -27,7 +27,8 @@ | |||
| #include "tools/optimizer/fusion/conv_bn_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" | |||
| #include "tools/optimizer/fusion/constant_folding_fusion.h" | |||
| #include "tools/optimizer/fusion/layer_norm_fusion.h" | |||
| #include "tools/optimizer/fusion/tf_norm_fusion.h" | |||
| #include "tools/optimizer/fusion/onnx_layer_norm_fusion.h" | |||
| #include "tools/optimizer/fusion/batchmatmul_fusion.h" | |||
| #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_conv_fusion.h" | |||
| @@ -77,7 +78,8 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti | |||
| auto conv_scale_pass = std::make_shared<opt::ConvScaleFusion>(); | |||
| conv_scale_pass->SetFmkType(config->fmk); | |||
| fusion_pm->AddPass(conv_scale_pass); | |||
| fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>()); | |||
| @@ -48,7 +48,12 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { | |||
| STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType, | |||
| FormatTransNodeType *afterNodeType) { | |||
| if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc | |||
| return RET_NO_CHANGE; | |||
| if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| *beforeNodeType = kNHWC2NCHW; | |||
| *afterNodeType = kNCHW2NHWC; | |||
| return RET_OK; | |||
| } else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS || | |||
| fmk_type_ == converter::FmkType_ONNX) { | |||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { | |||
| @@ -63,6 +68,11 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT | |||
| *afterNodeType = kNHWC2NCHW; | |||
| return RET_OK; | |||
| } | |||
| if (IsContain(GetNchwOpList(), GetCNodeTType(node))) { | |||
| *beforeNodeType = kNHWC2NCHW; | |||
| *afterNodeType = kNCHW2NHWC; | |||
| return RET_OK; | |||
| } | |||
| return RET_NO_CHANGE; | |||
| } | |||
| MS_LOG(ERROR) << "Unsupported fmk: " << fmk_type_; | |||
| @@ -1,65 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class LayerNormFusion : public PatternProcessPass { | |||
| public: | |||
| explicit LayerNormFusion(const std::string &name = "layer_norm_fusion", bool multigraph = true) | |||
| : PatternProcessPass(name, multigraph) { | |||
| input_ = std::make_shared<Var>(); | |||
| mean1_ = std::make_shared<Var>(); | |||
| mean1_axes_ = std::make_shared<Var>(); | |||
| mean2_ = std::make_shared<Var>(); | |||
| mean2_axes_ = std::make_shared<Var>(); | |||
| gamma_ = std::make_shared<Var>(); | |||
| beta_ = std::make_shared<Var>(); | |||
| epsilon_ = std::make_shared<Var>(); | |||
| } | |||
| ~LayerNormFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| bool GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, const std::vector<int> ¶ms_shape, | |||
| int *begin_norm_axis, int *begin_params_axis) const; | |||
| bool CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, int *begin_params_axis) const; | |||
| CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon, | |||
| int begin_norm_axis, int begin_params_axis) const; | |||
| VarPtr input_ = nullptr; | |||
| VarPtr mean1_ = nullptr; | |||
| VarPtr mean1_axes_ = nullptr; | |||
| VarPtr mean2_ = nullptr; | |||
| VarPtr mean2_axes_ = nullptr; | |||
| VarPtr gamma_ = nullptr; | |||
| VarPtr beta_ = nullptr; | |||
| VarPtr epsilon_ = nullptr; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/onnx_layer_norm_fusion.h" | |||
| #include <memory> | |||
| #include "ops/rsqrt.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef OnnxLayerNormFusion::DefinePattern() const { | |||
| VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_}); | |||
| VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref}); | |||
| VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref}); | |||
| VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()}); | |||
| VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_}); | |||
| VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_}); | |||
| VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref}); | |||
| VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref}); | |||
| VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref}); | |||
| VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_}); | |||
| return add2_ref; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "tools/optimizer/fusion/tf_norm_fusion.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class OnnxLayerNormFusion : public TfNormFusion { | |||
| public: | |||
| explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true) | |||
| : TfNormFusion(name, multigraph) {} | |||
| ~OnnxLayerNormFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| inline bool IsPowNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsSqrtNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsDivNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) || | |||
| CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion")); | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_ | |||
| @@ -13,11 +13,12 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/layer_norm_fusion.h" | |||
| #include "tools/optimizer/fusion/tf_norm_fusion.h" | |||
| #include <memory> | |||
| #include "ops/fusion/layer_norm_fusion.h" | |||
| #include "ops/fusion/reduce_fusion.h" | |||
| #include "ops/rsqrt.h" | |||
| #include "mindspore/core/ops/instance_norm.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| @@ -26,41 +27,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool IsAddNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion); | |||
| } | |||
| return false; | |||
| } | |||
| bool IsSquaredDifferenceNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference); | |||
| } | |||
| return false; | |||
| } | |||
| bool IsRsqrtNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt); | |||
| } | |||
| return false; | |||
| } | |||
| bool IsMulNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion); | |||
| } | |||
| return false; | |||
| } | |||
| bool IsSubNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion); | |||
| } | |||
| return false; | |||
| } | |||
| lite::STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) { | |||
| MS_ASSERT(node != nullptr); | |||
| if (utils::isa<ParameterPtr>(n)) { | |||
| @@ -106,7 +72,7 @@ bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr | |||
| } | |||
| } // namespace | |||
| const BaseRef LayerNormFusion::DefinePattern() const { | |||
| const BaseRef TfNormFusion::DefinePattern() const { | |||
| VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_}); | |||
| auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode); | |||
| VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref}); | |||
| @@ -128,13 +94,26 @@ const BaseRef LayerNormFusion::DefinePattern() const { | |||
| return add2_ref; | |||
| } | |||
| CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon, | |||
| int begin_norm_axis, int begin_params_axis) const { | |||
| CNodePtr TfNormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||
| const schema::PrimitiveType type, float epsilon, int begin_norm_axis, | |||
| int begin_params_axis) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(equiv != nullptr); | |||
| auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>(); | |||
| layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon); | |||
| auto value_node = NewValueNode(layer_norm_primitive); | |||
| auto norm_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| norm_primitive->value.type = type; | |||
| PrimitiveCPtr primitive = nullptr; | |||
| if (type == schema::PrimitiveType_LayerNormFusion) { | |||
| auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>(); | |||
| layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon, true); | |||
| primitive = layer_norm_primitive; | |||
| } else if (type == schema::PrimitiveType_InstanceNorm) { | |||
| auto instance_norm_primitive = std::make_shared<ops::InstanceNorm>(); | |||
| instance_norm_primitive->Init(epsilon); | |||
| primitive = instance_norm_primitive; | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| auto value_node = NewValueNode(primitive); | |||
| std::vector<AnfNodePtr> new_node_inputs = {value_node}; | |||
| auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]); | |||
| MS_ASSERT(input_node != nullptr); | |||
| @@ -149,10 +128,11 @@ CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, co | |||
| return new_node; | |||
| } | |||
| bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, | |||
| const std::vector<int> ¶ms_shape, int *begin_norm_axis, | |||
| int *begin_params_axis) const { | |||
| bool TfNormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, | |||
| const std::vector<int> ¶ms_shape, schema::PrimitiveType *type, | |||
| int *begin_norm_axis, int *begin_params_axis) const { | |||
| MS_ASSERT(input_node != nullptr); | |||
| MS_ASSERT(type != nullptr); | |||
| MS_ASSERT(begin_norm_axis != nullptr); | |||
| MS_ASSERT(begin_params_axis != nullptr); | |||
| auto abstract = input_cnode->abstract(); | |||
| @@ -170,30 +150,44 @@ bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector<int | |||
| return false; | |||
| } | |||
| auto shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||
| if (mean_axes.back() + 1 != static_cast<int>(shape.size())) { | |||
| MS_LOG(DEBUG) << "mean node is not reduce to last axis"; | |||
| return false; | |||
| } | |||
| for (size_t i = 1; i < mean_axes.size(); ++i) { | |||
| if (mean_axes[i] != mean_axes[i - 1] + 1) { | |||
| MS_LOG(DEBUG) << "mean axes is not continuous"; | |||
| return false; | |||
| } | |||
| } | |||
| if (shape.size() == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) { | |||
| if (params_shape.size() == 1 && params_shape.back() == shape.back()) { | |||
| *type = schema::PrimitiveType_InstanceNorm; | |||
| return true; | |||
| } | |||
| } | |||
| if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast<int>(shape.size())) { | |||
| MS_LOG(DEBUG) << "mean node is not reduce to last axis"; | |||
| return false; | |||
| } | |||
| // there is no need to check params_shape | |||
| *begin_norm_axis = mean_axes.front(); | |||
| *begin_params_axis = static_cast<int>(shape.size()) - static_cast<int>(params_shape.size()); | |||
| if (*begin_params_axis < 0) { | |||
| MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse"; | |||
| return false; | |||
| if (*begin_norm_axis >= 0) { | |||
| *begin_params_axis = static_cast<int>(shape.size()) - static_cast<int>(params_shape.size()); | |||
| if (*begin_params_axis < 0) { | |||
| MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse"; | |||
| return false; | |||
| } | |||
| } else { | |||
| *begin_params_axis = -static_cast<int>(params_shape.size()); | |||
| } | |||
| *type = schema::PrimitiveType_LayerNormFusion; | |||
| return true; | |||
| } | |||
| bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, | |||
| int *begin_params_axis) const { | |||
| bool TfNormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, | |||
| int *begin_norm_axis, int *begin_params_axis) const { | |||
| MS_ASSERT(equiv != nullptr); | |||
| MS_ASSERT(epsilon != nullptr); | |||
| MS_ASSERT(type != nullptr); | |||
| MS_ASSERT(begin_norm_axis != nullptr); | |||
| MS_ASSERT(begin_params_axis != nullptr); | |||
| // beta | |||
| @@ -243,9 +237,6 @@ bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *b | |||
| if (mean1_axes != mean2_axes) { | |||
| return false; | |||
| } | |||
| if (mean1_axes.size() != gamma_shape.size() || mean1_axes.size() != beta_shape.size()) { | |||
| return false; | |||
| } | |||
| if (gamma_shape != beta_shape) { | |||
| return false; | |||
| } | |||
| @@ -254,14 +245,14 @@ bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *b | |||
| } else { | |||
| return false; | |||
| } | |||
| if (!GetAxis(input_cnode, mean1_axes, gamma_shape, begin_norm_axis, begin_params_axis)) { | |||
| if (!GetNormTypeAndAxis(input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| const AnfNodePtr TfNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(equiv != nullptr); | |||
| @@ -273,14 +264,24 @@ const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const | |||
| float epsilon = 0.0f; | |||
| int begin_norm_axis = 0; | |||
| int begin_params_axis = 0; | |||
| if (!CheckPattern(equiv, &epsilon, &begin_norm_axis, &begin_params_axis)) { | |||
| schema::PrimitiveType type = schema::PrimitiveType_NONE; | |||
| if (!CheckPattern(equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) { | |||
| return nullptr; | |||
| } | |||
| auto norm_cnode = CreateNormNode(func_graph, equiv, type, epsilon, begin_norm_axis, begin_params_axis); | |||
| if (norm_cnode == nullptr) { | |||
| MS_LOG(DEBUG) << "create norm cnode failed"; | |||
| return nullptr; | |||
| } | |||
| auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, epsilon, begin_norm_axis, begin_params_axis); | |||
| layer_norm_cnode->set_abstract(add2_cnode->abstract()->Clone()); | |||
| layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); | |||
| MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success"; | |||
| return layer_norm_cnode; | |||
| norm_cnode->set_abstract(add2_cnode->abstract()->Clone()); | |||
| if (type == schema::PrimitiveType_LayerNormFusion) { | |||
| norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); | |||
| MS_LOG(INFO) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success"; | |||
| } else if (type == schema::PrimitiveType_InstanceNorm) { | |||
| norm_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope()); | |||
| MS_LOG(INFO) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success"; | |||
| } | |||
| return norm_cnode; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| /// fuse layer_norm, instance_norm into one operator | |||
| class TfNormFusion : public PatternProcessPass { | |||
| public: | |||
| explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true) | |||
| : PatternProcessPass(name, multigraph) { | |||
| input_ = std::make_shared<Var>(); | |||
| mean1_ = std::make_shared<Var>(); | |||
| mean1_axes_ = std::make_shared<Var>(); | |||
| mean2_ = std::make_shared<Var>(); | |||
| mean2_axes_ = std::make_shared<Var>(); | |||
| gamma_ = std::make_shared<Var>(); | |||
| beta_ = std::make_shared<Var>(); | |||
| epsilon_ = std::make_shared<Var>(); | |||
| } | |||
| ~TfNormFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| bool GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, | |||
| const std::vector<int> ¶ms_shape, schema::PrimitiveType *type, int *begin_norm_axis, | |||
| int *begin_params_axis) const; | |||
| bool CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, int *begin_norm_axis, | |||
| int *begin_params_axis) const; | |||
| CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type, | |||
| float epsilon, int begin_norm_axis, int begin_params_axis) const; | |||
| protected: | |||
| VarPtr input_ = nullptr; | |||
| VarPtr mean1_ = nullptr; | |||
| VarPtr mean1_axes_ = nullptr; | |||
| VarPtr mean2_ = nullptr; | |||
| VarPtr mean2_axes_ = nullptr; | |||
| VarPtr gamma_ = nullptr; | |||
| VarPtr beta_ = nullptr; | |||
| VarPtr epsilon_ = nullptr; | |||
| }; | |||
| inline bool IsAddNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsSquaredDifferenceNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsRsqrtNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsMulNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsSubNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion); | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_ | |||