From: @xu_anyue Reviewed-by: @hangangqiang,@hangangqiang Signed-off-by: @hangangqiangpull/14175/MERGE
| @@ -538,6 +538,8 @@ inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFus | |||||
| inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion"); | inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion"); | ||||
| inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion"); | inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion"); | ||||
| inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType"); | inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType"); | ||||
| inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion"); | |||||
| inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf"); | |||||
| class DoSignaturePrimitive : public Primitive { | class DoSignaturePrimitive : public Primitive { | ||||
| public: | public: | ||||
| @@ -34,8 +34,6 @@ using mindspore::lite::RET_OK; | |||||
| using mindspore::lite::STATUS; | using mindspore::lite::STATUS; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion"); | |||||
| inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf"); | |||||
| inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple"); | inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple"); | ||||
| inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity"); | inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity"); | ||||
| constexpr auto kWeightFormat = "weight_format"; | constexpr auto kWeightFormat = "weight_format"; | ||||
| @@ -297,7 +297,8 @@ const BaseRef OnnxLayerNormFusion::DefinePattern() const { | |||||
| VectorRef add1_ref = | VectorRef add1_ref = | ||||
| VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mean2_ref, epsilon_}); | VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mean2_ref, epsilon_}); | ||||
| VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>), add1_ref}); | VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>), add1_ref}); | ||||
| VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), sub1_ref, sqrt_ref}); | |||||
| VectorRef div_ref = | |||||
| VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>), sub1_ref, sqrt_ref}); | |||||
| VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), gamma_, div_ref}); | VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), gamma_, div_ref}); | ||||
| VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mul_ref, beta_}); | VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mul_ref, beta_}); | ||||
| return add2_ref; | return add2_ref; | ||||
| @@ -27,8 +27,8 @@ constexpr float MUL1_y = 0.5; | |||||
| // gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))] | // gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))] | ||||
| const BaseRef OnnxGeLUFusion::DefinePattern() const { | const BaseRef OnnxGeLUFusion::DefinePattern() const { | ||||
| VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_}); | |||||
| VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref}); | |||||
| VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>), input_, div_y_}); | |||||
| VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimErf>), div_ref}); | |||||
| VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_}); | VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_}); | ||||
| VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_}); | VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_}); | ||||
| VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref}); | VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref}); | ||||