|
|
|
@@ -71,11 +71,11 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn |
|
|
|
{split_v_output0_shape, split_v_output1_shape}, split_v.get()); |
|
|
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, |
|
|
|
MakeValue(std::vector<int>{SizeToInt((origin_output2_shape[2] + 15) / 16), |
|
|
|
SizeToInt((origin_output3_shape[1] + 15) / 16)}), |
|
|
|
MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16), |
|
|
|
SizeToLong((origin_output3_shape[1] + 15) / 16)}), |
|
|
|
split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v); |
|
|
|
|
|
|
|
basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad); |
|
|
|
matmul_nodes.emplace_back(matmul); |
|
|
|
@@ -88,15 +88,15 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn |
|
|
|
|
|
|
|
AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input, |
|
|
|
const std::vector<std::vector<size_t>> &split_shapes, |
|
|
|
const std::vector<TypeId> &split_types, const std::vector<int> &size_split, |
|
|
|
const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split, |
|
|
|
size_t num_split_x) { |
|
|
|
std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), |
|
|
|
input}; |
|
|
|
auto lstm_split = func_graph->NewCNode(lstm_split_input); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), lstm_split); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToInt(num_split_x)), lstm_split); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), lstm_split); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(num_split_x)), lstm_split); |
|
|
|
return lstm_split; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -112,7 +112,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr & |
|
|
|
size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0]; |
|
|
|
std::vector<std::vector<size_t>> split_shapes; |
|
|
|
std::vector<TypeId> split_types; |
|
|
|
std::vector<int> size_split; |
|
|
|
std::vector<int64_t> size_split; |
|
|
|
for (size_t i = 0; i < num_split_x; ++i) { |
|
|
|
split_shapes.emplace_back(split_c_dims); |
|
|
|
split_types.emplace_back(kNumberTypeFloat32); |
|
|
|
@@ -238,9 +238,9 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr & |
|
|
|
auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)}, |
|
|
|
lstm_x_concat.get()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_x_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_x_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_x_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_x_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_x_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), lstm_x_concat); |
|
|
|
|
|
|
|
// Create lstm_gage_concat |
|
|
|
auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input); |
|
|
|
@@ -248,8 +248,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr & |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, |
|
|
|
{{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}}, |
|
|
|
lstm_gage_concat.get()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_gage_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_gage_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat); |
|
|
|
|
|
|
|
outputs->emplace_back(lstm_x_concat); |
|
|
|
@@ -274,9 +274,10 @@ AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_ |
|
|
|
std::vector<std::vector<size_t>> shapes = {shape1, shape2}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get()); |
|
|
|
// Set attr |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int>{SizeToInt(origin_input6_shape[0] - 1), 1}), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(SizeToLong(0)), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(2)), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int64_t>{SizeToLong(origin_input6_shape[0] - 1), 1}), |
|
|
|
split_v); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v); |
|
|
|
return split_v; |
|
|
|
} |
|
|
|
@@ -315,9 +316,9 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic |
|
|
|
std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get()); |
|
|
|
// Set attr |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat); |
|
|
|
return concat; |
|
|
|
} |
|
|
|
@@ -338,9 +339,9 @@ AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_ |
|
|
|
origin_output0_shape[2] + h_concat_output_shape[2]}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get()); |
|
|
|
// Set attr |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(2)), concat); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat); |
|
|
|
return concat; |
|
|
|
} |
|
|
|
@@ -373,9 +374,9 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy |
|
|
|
std::vector<size_t> shape = {origin_input0_shape[0], origin_input0_shape[1], origin_input0_shape[2] + shape_tmp[2]}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get()); |
|
|
|
// Set attr |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(2)), concat); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat); |
|
|
|
return concat; |
|
|
|
} |
|
|
|
@@ -410,7 +411,7 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)}, |
|
|
|
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get()); |
|
|
|
// Set attr |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sum); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum); |
|
|
|
return reduce_sum; |
|
|
|
@@ -427,7 +428,7 @@ AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 1)}, |
|
|
|
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 1)}, reduce_sum.get()); |
|
|
|
// Set attr |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sum); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sum); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum); |
|
|
|
return reduce_sum; |
|
|
|
|