@@ -46,6 +46,8 @@
 #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h"
 #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
 #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h"
+#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
+#include "pre_activate/ascend/ir_fusion/derelu_fusion.h"
 #include "pre_activate/ascend/format_type/insert_trans_op.h"
 #include "pre_activate/pass/getitem_tuple.h"
 #include "pre_activate/pass/optimize_dependence.h"
@@ -94,8 +96,10 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
   ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
   ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
-  ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
+  ir_fusion_pm->AddPass(std::make_shared<DereluFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
   ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
 }
 }  // namespace
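For context, this kind of pass manager runs its passes in the order they were registered, so the reordering above matters: DereluFusion and ConfusionMulGradFusion now run before TransposeTransDataFusion, and GetitemTuple moves after it. A minimal sketch of that contract (ToyPassManager is a hypothetical stand-in, not the real PassManager API):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the real PassManager: AddPass appends and Run
// executes in insertion order, which is why moving GetitemTuple below
// TransposeTransDataFusion changes which graph shape it observes.
struct ToyPassManager {
  std::vector<std::pair<std::string, std::function<void()>>> passes;
  void AddPass(std::string name, std::function<void()> pass) {
    passes.emplace_back(std::move(name), std::move(pass));
  }
  void Run() {
    for (auto &p : passes) {
      std::cout << "running " << p.first << "\n";  // insertion order
      p.second();
    }
  }
};
```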
@@ -18,6 +18,7 @@
 #include <memory>
 #include <vector>
 #include <algorithm>
+#include <string>
 #include "session/anf_runtime_algorithm.h"
 #include "ir/primitive.h"
 #include "utils/utils.h"
@@ -89,6 +90,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
   auto reduce_sum = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(reduce_sum);
   auto mul1 = reduce_sum->input(1);
+  if (mul1->fullname_with_scope().find("bert/encoder") == std::string::npos) {
+    return nullptr;
+  }
   if (IsUsedByOthers(graph, mul1)) {
     MS_LOG(INFO) << "Mul1 is used by others, quit fusion!";
     return nullptr;
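The new guard keys the fusion off the node's fully scoped name. A self-contained sketch of the equivalent check (the helper name and the example scope string are hypothetical):

```cpp
#include <string>

// Hypothetical predicate equivalent to the guard above: the fusion only
// proceeds when the node's full scoped name contains "bert/encoder",
// e.g. "Default/bert/encoder/layer_0/attention/Mul".
bool InBertEncoderScope(const std::string &fullname_with_scope) {
  return fullname_with_scope.find("bert/encoder") != std::string::npos;
}
```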
@@ -50,9 +50,22 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
   MS_EXCEPTION_IF_NULL(new_node);
   new_node->set_scope(relu->scope());
-  // ReluV2's 2rd output is mask whose data type is uint8 and value is 0 or 1, so shape is an empty vector
+  // ReluV2's 2nd output is a mask whose data type is uint8
   TypeId mask_dtype = kNumberTypeUInt8;
-  std::vector<size_t> mask_shape;
+  std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0);
+  if (mask_shape.size() != 4) {
+    MS_LOG(WARNING) << "relu's infer shape size is not equal to 4";
+    return nullptr;
+  }
+  auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0);
+  if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) {
+    mask_shape[1] = (mask_shape[1] + 31) / 32;
+    mask_shape.push_back(4);
+  } else {
+    mask_shape[1] = (mask_shape[1] + 15) / 16;
+    mask_shape.push_back(2);
+  }
   auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype};
   auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape};
   AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
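The replacement shape logic pads the channel dimension up to a hardware block (C0) boundary: 32 channels per block when the input is int8/uint8, otherwise 16, plus a trailing bitmask dimension. A standalone sketch of the same arithmetic (the helper name and the NCHW layout assumption are mine):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the mask-shape computation above. The input
// shape is assumed to be 4-D NCHW, which the pass verifies before fusing.
std::vector<std::size_t> ReluV2MaskShape(std::vector<std::size_t> shape, bool int8_like) {
  const std::size_t c0 = int8_like ? 32 : 16;  // channels per hardware block (C0)
  shape[1] = (shape[1] + c0 - 1) / c0;         // ceil(C / c0)
  shape.push_back(int8_like ? 4 : 2);          // trailing bitmask dimension
  return shape;
}

// Example: a (2, 17, 5, 5) fp16 tensor yields a (2, 2, 5, 5, 2) uint8 mask,
// since ceil(17 / 16) == 2.
```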
@@ -91,6 +104,9 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
   MS_EXCEPTION_IF_NULL(relu);
   auto relu_v2 = CreateReluV2(graph, relu);
+  if (relu_v2 == nullptr) {
+    return nullptr;
+  }
   std::vector<AnfNodePtr> relu_v2_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs);
@@ -120,7 +120,7 @@ constexpr auto kStreamActiveOpName = "StreamActive";
 constexpr auto kAssignAddOpName = "AssignAdd";
 constexpr auto kSendOpName = "Send";
 constexpr auto kRecvOpName = "Recv";
-constexpr auto kReluV2OpName = "ReluV2";
+constexpr auto kReluV2OpName = "ReLUV2";
 constexpr auto kReluGradV2OpName = "ReluGradV2";
 // attr key name
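Note: the constant's value changes from "ReluV2" to "ReLUV2", matching the primitive name change made on the Python side in the test below.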
@@ -32,6 +32,11 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon {
 TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) {
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before");
   EXPECT_NE(g, nullptr);
+  auto bert_scope = std::make_shared<Scope>("bert/encoder");
+  for (auto node : TopoSort(g->get_return())) {
+    node->set_scope(bert_scope);
+  }
   std::vector<int> shp{1, 1, 1, 1};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
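Note: forcing every node's scope to "bert/encoder" here is what satisfies the scope guard added to ConfusionMulGradFusion::Process above; without it the fusion would bail out before the pattern match.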
@@ -17,7 +17,7 @@ from mindspore.ops import Primitive
 relu = P.ReLU()
 relu_grad = Primitive('ReluGrad')
-relu_v2 = Primitive('ReluV2')
+relu_v2 = Primitive('ReLUV2')
 relu_grad_v2 = Primitive('ReluGradV2')
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')