
adjust dynamic_rnn_grad_fission_v2 position

tags/v1.1.0
liubuyu, 5 years ago
commit 5bf70b24bb
2 changed files with 10 additions and 14 deletions:
  1. mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc (+2, -2)
  2. mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc (+8, -12)
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc (+2, -2)

@@ -279,11 +279,11 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   }
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
-  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
-  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
   AddAscendIRFusionRulesPass(ir_fusion_pm.get());
   AddAscendIRFusionPass(ir_fusion_pm.get());
+  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
 
   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
       ConfigManager::GetInstance().iter_num() > 1) {
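The change in this file is pure reordering: InsertPlaceholderForDynamicRNN and DynamicRnnGradFissionV2 are moved from before AddAscendIRFusionRulesPass/AddAscendIRFusionPass to after them. A pass manager executes passes in the order they are registered, so moving the two AddPass calls changes which graph the fission sees. A minimal sketch of that ordering contract (illustrative only, not MindSpore's actual PassManager):

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Graph {};  // stand-in for session::KernelGraph

// Passes run strictly in registration order, mirroring how the order of
// ir_fusion_pm->AddPass(...) calls decides execution order.
class PassManagerSketch {
 public:
  void AddPass(const std::string &name, std::function<void(Graph *)> run) {
    passes_.emplace_back(name, std::move(run));
  }
  void Run(Graph *g) const {
    for (const auto &p : passes_) {
      std::cout << "running " << p.first << '\n';
      p.second(g);
    }
  }

 private:
  std::vector<std::pair<std::string, std::function<void(Graph *)>>> passes_;
};

int main() {
  Graph g;
  PassManagerSketch pm;
  pm.AddPass("InsertPlaceholderForDynamicGRUV2", [](Graph *) {});
  pm.AddPass("AscendIRFusionRules", [](Graph *) {});  // AddAscendIRFusionRulesPass
  pm.AddPass("AscendIRFusion", [](Graph *) {});       // AddAscendIRFusionPass
  // After this commit the RNN placeholder/fission pair is registered last,
  // so it now sees the graph the fusion rules have already rewritten.
  pm.AddPass("InsertPlaceholderForDynamicRNN", [](Graph *) {});
  pm.AddPass("DynamicRnnGradFissionV2", [](Graph *) {});
  pm.Run(&g);
  return 0;
}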


mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc (+8, -12)

@@ -23,7 +23,6 @@ namespace opt {
 namespace {
 constexpr size_t kDynamicRNNGradInputNum = 16;
 constexpr size_t kSplitVOutputNum = 2;
-constexpr size_t kLSTMInputGradOutputNum = 4;
 
 void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                      std::vector<std::vector<AnfNodePtr>> *result_nodes) {
@@ -34,8 +33,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
   std::vector<AnfNodePtr> matmul_nodes;
   std::vector<AnfNodePtr> split_nodes;
   // Get the size of t
-  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(10), 0);
-  size_t t_size = origin_input9_shape[0];
+  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(11), 0);
+  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(9), 0)[0];
   auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(12), 0);
 
   for (size_t i = 0; i < t_size; ++i) {
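A note on the indexing this hunk relies on: a CNode's input(0) holds the primitive itself, so input(k) is the k-th operand, and the origin_inputN locals in this file name operands 0-based (origin_input4 is input(5), origin_input7 is input(8)), which is why origin_input9 was read from input(10). t_size is simply dimension 0 of an operand's inferred shape. A self-contained sketch of that lookup, with invented shapes (t = 5, batch = 2, hidden = 16):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for an ANF node whose inferred output shape is known, i.e.
// what AnfAlgo::GetOutputInferShape(node, 0) would return.
struct NodeSketch {
  std::string name;
  std::vector<size_t> shape;
};

// Stand-in for a CNode: inputs[0] is the primitive, operands start at 1.
struct CNodeSketch {
  std::vector<NodeSketch> inputs;
  const NodeSketch &input(size_t i) const { return inputs.at(i); }
};

int main() {
  CNodeSketch dynamic_rnn_grad;
  dynamic_rnn_grad.inputs.resize(17);  // primitive + 16 operands (kDynamicRNNGradInputNum)
  dynamic_rnn_grad.inputs[0] = {"DynamicRNNGrad", {}};
  dynamic_rnn_grad.inputs[9] = {"operand_8", {5, 2, 16}};  // invented shape

  // Mirrors: size_t t_size = AnfAlgo::GetOutputInferShape(cnode->input(9), 0)[0];
  size_t t_size = dynamic_rnn_grad.input(9).shape[0];
  std::cout << "t_size = " << t_size << '\n';  // prints t_size = 5
  return 0;
}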
@@ -55,7 +54,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
     auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(2), 0);
     std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
     auto matmul = func_graph->NewCNode(matmul_inputs);
-    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{output0_dims[0], origin_input1_shape[0]}},
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
                                         matmul.get());
     AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
     AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);
@@ -65,8 +64,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
     auto split_v = func_graph->NewCNode(splitv_input);
     auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
     auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
-    std::vector<size_t> split_v_output0_shape{origin_output2_shape[1], origin_output2_shape[2]};
-    std::vector<size_t> split_v_output1_shape{origin_output3_shape[0], origin_output3_shape[1]};
+    std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[1], origin_output2_shape[2]};
+    std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[0], origin_output3_shape[1]};
     AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
                                         {split_v_output0_shape, split_v_output1_shape}, split_v.get());

@@ -74,7 +73,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
                          MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16),
                                                         SizeToLong((origin_output3_shape[1] + 15) / 16)}),
                          split_v);
-    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), split_v);
+    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(1)), split_v);
     AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);
 
     basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
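Two related things happen around the SplitV node: its output shapes gain a leading dimension of 1 (with the bare 1 written as IntToSize(1) to make the int-to-size_t conversion explicit), so the axis being split shifts right by one and kAttrSplitDim moves from 0 to 1. The split sizes use (x + 15) / 16, integer ceiling division: the number of 16-wide fragments covering x, matching Ascend's fractal tiling in blocks of 16. A quick self-contained check of that arithmetic, with illustrative values only:

#include <cstdint>
#include <iostream>
#include <vector>

// (x + 15) / 16 is integer ceiling division: how many 16-wide fragments
// are needed to cover x elements.
int64_t CeilDiv16(int64_t x) { return (x + 15) / 16; }

int main() {
  // Illustrative sizes: 32 divides evenly, 17 needs a partial block.
  std::vector<int64_t> size_splits = {CeilDiv16(32), CeilDiv16(17)};
  std::cout << size_splits[0] << ' ' << size_splits[1] << '\n';  // prints 2 2

  // Shape {h, w} becoming {1, h, w} shifts every axis right by one,
  // which is why the split dimension changes from 0 to 1.
  std::vector<size_t> old_shape = {4, 6};
  std::vector<size_t> new_shape = {1, old_shape[0], old_shape[1]};
  std::cout << "rank " << old_shape.size() << " -> " << new_shape.size()
            << ", split axis 0 -> 1\n";
  return 0;
}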
@@ -106,7 +105,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
   CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
 
   auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0);
-  std::vector<size_t> split_c_dims{1, origin_input5_shape[0], origin_input5_shape[1]};
+  std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
 
   auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
   size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
@@ -250,7 +249,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
                                       lstm_gage_concat.get());
   AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
   AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), lstm_gage_concat);
 
   outputs->emplace_back(lstm_x_concat);
   outputs->emplace_back(pre_split_outputs[1]);
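MakeValue resolves on the C++ type of its argument, so MakeValue(0) stores a 32-bit integer attribute while MakeValue(SizeToLong(0)) stores an int64_t, consistent with the other kAttrAxis/kAttrSplitDim attributes in this file that are built from int64_t. A minimal stand-in showing how the literal's type picks the overload (MakeValueSketch and SizeToLongSketch are hypothetical stand-ins, not the MindSpore functions):

#include <cstddef>
#include <cstdint>
#include <iostream>

// Hypothetical stand-ins for MakeValue/SizeToLong, demonstrating overload
// resolution on the argument's C++ type.
void MakeValueSketch(int32_t) { std::cout << "stored as Int32 attr\n"; }
void MakeValueSketch(int64_t) { std::cout << "stored as Int64 attr\n"; }

int64_t SizeToLongSketch(size_t v) { return static_cast<int64_t>(v); }

int main() {
  MakeValueSketch(0);                    // literal 0 is int -> Int32 attr
  MakeValueSketch(SizeToLongSketch(0));  // explicit widening -> Int64 attr
  return 0;
}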
@@ -305,9 +304,7 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
   std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            origin_input4};
   auto reshape = func_graph->NewCNode(reshape_input);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
-
   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                            reshape, splitv_outputs[0]};
   auto concat = func_graph->NewCNode(concat_inputs);
@@ -363,7 +360,6 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
   std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            origin_input4};
   auto reshape = func_graph->NewCNode(reshape_input);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
 
   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
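Both CreateHConcat and CreateConcatNodeT1 stop tagging the freshly created Reshape with kAttrVisited. That attribute is the usual marker an optimizer pass sets so later pattern matching skips a node it has already handled; dropping it leaves the Reshape eligible for further rewriting. A generic sketch of the visited-flag idiom (not MindSpore's pass driver):

#include <iostream>
#include <string>
#include <unordered_set>

int main() {
  // Stand-in for nodes carrying the kAttrVisited attribute.
  std::unordered_set<std::string> visited = {"reshape_marked"};

  for (const std::string &node : {"reshape_marked", "reshape_unmarked"}) {
    // A pattern pass typically skips nodes already marked as visited.
    if (visited.count(node) != 0) {
      std::cout << node << ": skipped by later matching\n";
    } else {
      std::cout << node << ": still eligible for matching\n";
    }
  }
  return 0;
}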

