From 5bf70b24bbcb6be9b84ba90c258f126a2f733b71 Mon Sep 17 00:00:00 2001
From: liubuyu
Date: Mon, 16 Nov 2020 19:47:08 +0800
Subject: [PATCH] adjust dynamic_rnn_grad_fission_v2 position

---
 .../ascend/ascend_backend_optimization.cc     |  4 ++--
 .../ir_fission/dynamic_rnn_grad_fission_v2.cc | 20 ++++++++-----------
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 4f1a083b57..8759e549c2 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -279,11 +279,11 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGra
 
   ir_fusion_pm->AddPass(std::make_shared<...>());
   ir_fusion_pm->AddPass(std::make_shared<...>());
-  ir_fusion_pm->AddPass(std::make_shared<...>());
   ir_fusion_pm->AddPass(std::make_shared<...>());
-  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
   AddAscendIRFusionRulesPass(ir_fusion_pm.get());
   AddAscendIRFusionPass(ir_fusion_pm.get());
+  ir_fusion_pm->AddPass(std::make_shared<...>());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
 
   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
       ConfigManager::GetInstance().iter_num() > 1) {
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
index 2846938d0c..29bd17b1db 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
@@ -23,7 +23,6 @@ namespace opt {
 namespace {
 constexpr size_t kDynamicRNNGradInputNum = 16;
 constexpr size_t kSplitVOutputNum = 2;
-constexpr size_t kLSTMInputGradOutputNum = 4;
 
 void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                      std::vector<std::vector<AnfNodePtr>> *result_nodes) {
@@ -34,8 +33,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
   std::vector<AnfNodePtr> matmul_nodes;
   std::vector<AnfNodePtr> split_nodes;
   // Get the size of t
-  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(10), 0);
-  size_t t_size = origin_input9_shape[0];
+  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(11), 0);
+  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(9), 0)[0];
   auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(12), 0);
 
   for (size_t i = 0; i < t_size; ++i) {
@@ -55,7 +54,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
     auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(2), 0);
     std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
     auto matmul = func_graph->NewCNode(matmul_inputs);
-    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{output0_dims[0], origin_input1_shape[0]}},
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
                                         matmul.get());
     AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
     AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);
@@ -65,8 +64,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
     std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), matmul};
     auto split_v = func_graph->NewCNode(splitv_input);
     auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
     auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
-    std::vector<size_t> split_v_output0_shape{origin_output2_shape[1], origin_output2_shape[2]};
-    std::vector<size_t> split_v_output1_shape{origin_output3_shape[0], origin_output3_shape[1]};
+    std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[1], origin_output2_shape[2]};
+    std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[0], origin_output3_shape[1]};
     AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
                                         {split_v_output0_shape, split_v_output1_shape}, split_v.get());
@@ -74,7 +73,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
                          MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16),
                                                         SizeToLong((origin_output3_shape[1] + 15) / 16)}),
                          split_v);
-    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), split_v);
+    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(1)), split_v);
     AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);
 
     basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
@@ -106,7 +105,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
   CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
 
   auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0);
-  std::vector<size_t> split_c_dims{1, origin_input5_shape[0], origin_input5_shape[1]};
+  std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
 
   auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
   size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
@@ -250,7 +249,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
                                      lstm_gage_concat.get());
   AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
   AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), lstm_gage_concat);
 
   outputs->emplace_back(lstm_x_concat);
   outputs->emplace_back(pre_split_outputs[1]);
@@ -305,9 +304,7 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
   std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            origin_input4};
   auto reshape = func_graph->NewCNode(reshape_input);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
-
   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                            reshape, splitv_outputs[0]};
   auto concat = func_graph->NewCNode(concat_inputs);
@@ -363,7 +360,6 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
   std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            origin_input4};
   auto reshape = func_graph->NewCNode(reshape_input);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
 
   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
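The common thread in the fission hunks above is that each per-timestep node now carries an explicit leading time dimension of 1, so the per-step SplitV splits along axis 1 rather than axis 0, and its size-splits are expressed as counts of 16-element blocks via the ceiling division (d + 15) / 16. Below is a minimal standalone sketch of that shape arithmetic; the helper name PerStepSplitVShapes and the sample dimensions are illustrative assumptions, not MindSpore APIs.

#include <cstddef>
#include <iostream>
#include <vector>

// Mirrors the patched shape logic: each per-step SplitV output keeps a
// leading dimension of 1 (one time step), prepended to the original 2-D
// shapes taken from DynamicRNNGrad outputs 2 and 3.
std::vector<std::vector<size_t>> PerStepSplitVShapes(const std::vector<size_t> &out2_shape,
                                                     const std::vector<size_t> &out3_shape) {
  std::vector<size_t> shape0{1, out2_shape[1], out2_shape[2]};  // was {out2_shape[1], out2_shape[2]}
  std::vector<size_t> shape1{1, out3_shape[0], out3_shape[1]};  // was {out3_shape[0], out3_shape[1]}
  return {shape0, shape1};
}

int main() {
  // Sample shapes (made-up values): output 2 is [t, batch, hidden],
  // output 3 is [batch, hidden].
  const std::vector<size_t> out2{8, 16, 32};
  const std::vector<size_t> out3{16, 32};
  for (const auto &shape : PerStepSplitVShapes(out2, out3)) {
    for (size_t d : shape) std::cout << d << ' ';
    std::cout << '\n';  // prints "1 16 32" twice for these inputs
  }
  // Size-splits attribute as 16-element block counts, matching the
  // patch's SizeToLong((dim + 15) / 16) ceiling division.
  std::cout << (out2[2] + 15) / 16 << ' ' << (out3[1] + 15) / 16 << '\n';  // prints "2 2"
  return 0;
}

With a leading 1 on every per-step output, the per-timestep results presumably stack directly along axis 0 in the later Concat over t (the hunk that sets kAttrAxis), which would explain the split dimension moving from 0 to 1 in the same change.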