@@ -23,7 +23,6 @@ namespace opt {
 namespace {
 constexpr size_t kDynamicRNNGradInputNum = 16;
 constexpr size_t kSplitVOutputNum = 2;
-constexpr size_t kLSTMInputGradOutputNum = 4;
 void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                      std::vector<std::vector<AnfNodePtr>> *result_nodes) {
@@ -34,8 +33,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
   std::vector<AnfNodePtr> matmul_nodes;
   std::vector<AnfNodePtr> split_nodes;
   // Get the size of t
-  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(10), 0);
-  size_t t_size = origin_input9_shape[0];
+  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(11), 0);
+  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(9), 0)[0];
   auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(12), 0);
   for (size_t i = 0; i < t_size; ++i) {
@@ -55,7 +54,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
     auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(2), 0);
     std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
     auto matmul = func_graph->NewCNode(matmul_inputs);
-    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{output0_dims[0], origin_input1_shape[0]}},
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
                                         matmul.get());
     AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
     AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);
@@ -65,8 +64,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
     auto split_v = func_graph->NewCNode(splitv_input);
     auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
     auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
-    std::vector<size_t> split_v_output0_shape{origin_output2_shape[1], origin_output2_shape[2]};
-    std::vector<size_t> split_v_output1_shape{origin_output3_shape[0], origin_output3_shape[1]};
+    std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[1], origin_output2_shape[2]};
+    std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[0], origin_output3_shape[1]};
     AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
                                         {split_v_output0_shape, split_v_output1_shape}, split_v.get());
@@ -74,7 +73,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
                          MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16),
                                                         SizeToLong((origin_output3_shape[1] + 15) / 16)}),
                          split_v);
-    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), split_v);
+    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(1)), split_v);
     AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);
 
     basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
@@ -106,7 +105,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
   CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
 
   auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0);
-  std::vector<size_t> split_c_dims{1, origin_input5_shape[0], origin_input5_shape[1]};
+  std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
 
   auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
   size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
@@ -250,7 +249,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
                                       lstm_gage_concat.get());
   AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
   AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), lstm_gage_concat);
 
   outputs->emplace_back(lstm_x_concat);
   outputs->emplace_back(pre_split_outputs[1]);
@@ -305,9 +304,7 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
   std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            origin_input4};
   auto reshape = func_graph->NewCNode(reshape_input);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
 
   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                            reshape, splitv_outputs[0]};
   auto concat = func_graph->NewCNode(concat_inputs);
@@ -363,7 +360,6 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
   std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            origin_input4};
   auto reshape = func_graph->NewCNode(reshape_input);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
 
   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
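
The standalone sketch below is not part of the patch; it only illustrates, with hypothetical shape values, the convention the changed lines adopt: per-timestep intermediate outputs keep a leading dimension of 1 (so each step stays 3-D and the per-step results can later be concatenated over t along axis 0), which is why the SplitV split dimension moves from 0 to 1.

// Illustrative sketch only, not code from the patch; all shape values are hypothetical.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical inferred shapes standing in for DynamicRNNGrad outputs 2 and 3.
  std::vector<std::size_t> origin_output2_shape{16, 32, 64};
  std::vector<std::size_t> origin_output3_shape{32, 48};

  // Before the change: per-step SplitV outputs were 2-D and split along dim 0.
  std::vector<std::size_t> old_output0{origin_output2_shape[1], origin_output2_shape[2]};
  const int64_t old_split_dim = 0;

  // After the change: a leading 1 keeps each step 3-D, so the split happens on dim 1
  // and the later Concat over t can run along axis 0.
  std::vector<std::size_t> new_output0{1, origin_output2_shape[1], origin_output2_shape[2]};
  const int64_t new_split_dim = 1;

  std::cout << "old rank " << old_output0.size() << ", split_dim " << old_split_dim << "\n";
  std::cout << "new rank " << new_output0.size() << ", split_dim " << new_split_dim << "\n";
  return 0;
}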