| @@ -29,6 +29,8 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const size_t kGenMaskInputIndex = 1; | const size_t kGenMaskInputIndex = 1; | ||||
| const size_t K_GEN_MASK_FUSED_INPUT_INDEX1 = 2; | |||||
| const size_t K_GEN_MASK_FUSED_INPUT_INDEX2 = 3; | |||||
| const size_t kDefaultMaxParallelNum = 1; | const size_t kDefaultMaxParallelNum = 1; | ||||
| } // namespace | } // namespace | ||||
| @@ -93,10 +95,29 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const { | |||||
| return true; | return true; | ||||
| } | } | ||||
| void GetMatMulFusionNodes(const NodePtr &node, NodePtr &gen_mask) { | |||||
| // "batch_matmul + dropout_do_mask" is transformed to batch_matmul in a ub fusion pass | |||||
| // node gen_mask is located at different place in the fused node | |||||
| auto in_data_nodes = node->GetInDataNodes(); | |||||
| if (in_data_nodes.size() > K_GEN_MASK_FUSED_INPUT_INDEX1 && node->GetType() == "BatchMatMul") { | |||||
| NodePtr &gen_mask_candidate = in_data_nodes.at(K_GEN_MASK_FUSED_INPUT_INDEX1); | |||||
| if (gen_mask_candidate->GetName().find("DropOutGenMaskV3") != gen_mask_candidate->GetName().npos) { | |||||
| gen_mask = gen_mask_candidate; | |||||
| } | |||||
| } else if (in_data_nodes.size() > K_GEN_MASK_FUSED_INPUT_INDEX2 && node->GetType() == "MatMulV2") { | |||||
| NodePtr &gen_mask_candidate = in_data_nodes.at(K_GEN_MASK_FUSED_INPUT_INDEX2); | |||||
| if (gen_mask_candidate->GetName().find("DropOutGenMaskV3") != gen_mask_candidate->GetName().npos) { | |||||
| gen_mask = gen_mask_candidate; | |||||
| } | |||||
| } | |||||
| } | |||||
| void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const { | void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const { | ||||
| set<NodePtr> nodes_set; | set<NodePtr> nodes_set; | ||||
| for (const NodePtr &node : graph->GetDirectNode()) { | for (const NodePtr &node : graph->GetDirectNode()) { | ||||
| if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) { | |||||
| bool not_dropout_do_mask_flag = (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && | |||||
| node->GetType() != DROPOUTDOMASKV3D && node->GetType() != "BatchMatMul" && node->GetType() != "MatMulV2"); | |||||
| if (not_dropout_do_mask_flag) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -107,6 +128,9 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node | |||||
| auto in_data_nodes = node->GetInDataNodes(); | auto in_data_nodes = node->GetInDataNodes(); | ||||
| if (in_data_nodes.size() > kGenMaskInputIndex) { | if (in_data_nodes.size() > kGenMaskInputIndex) { | ||||
| NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); | NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); | ||||
| GetMatMulFusionNodes(node, gen_mask); | |||||
| if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { | if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { | ||||
| continue; | continue; | ||||
| } | } | ||||