@@ -15,9 +15,11 @@
  */
 #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
 #include <vector>
+#include <string>
 #include <memory>
 #include "backend/session/kernel_graph.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/optimizer/ascend/ascend_helper.h"
 #include "utils/trace_base.h"
+#include "utils/tensor_construct_utils.h"
@@ -34,9 +36,40 @@ constexpr int64_t kAttrAxis2Value = 2;
 constexpr int64_t kAttrNumSplitValue = 2;
 constexpr int64_t kAttrSplitDimValue = 2;
 constexpr size_t kDimMultiNum = 4;
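 
+// Record the true (unpadded) input/hidden sizes on the weight-input chain and on
+// every TupleGetItem user of the grad node; downstream passes use these attrs to
+// derive the RNN-specific padded formats (see FRACTAL_ZN_RNN / ND_RNN_BIAS below).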
+void SetAttrInputAndHiddenSize(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                               int64_t input_size, int64_t hidden_size) {
+  auto input = dynamic_rnn_grad_cnode->input(kIndex2);
+  MS_EXCEPTION_IF_NULL(input);
+  // set for input
+  while (input->isa<CNode>()) {
+    AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(input_size), input);
+    AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(hidden_size), input);
+    auto input_cnode = input->cast<CNodePtr>();
+    input = input_cnode->input(kIndex1);
+  }
+  if (input->isa<Parameter>()) {
+    auto param = input->cast<ParameterPtr>();
+    param->set_input_size(input_size);
+    param->set_hidden_size(hidden_size);
+  }
+
+  // set for output
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  for (auto getitem_index : manager->node_users()[dynamic_rnn_grad_cnode]) {
+    if (AnfAlgo::CheckPrimitiveType(getitem_index.first, prim::kPrimTupleGetItem)) {
+      for (auto node_index : manager->node_users()[getitem_index.first]) {
+        AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(input_size), node_index.first);
+        AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(hidden_size), node_index.first);
+      }
+    }
+  }
+}
 }  // namespace
 
 void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                                              const RNNShapeSpecs &specs,
                                               std::vector<std::vector<AnfNodePtr>> *result_nodes) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
@@ -45,19 +78,15 @@ void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, co
   std::vector<AnfNodePtr> matmul_nodes;
   std::vector<AnfNodePtr> split_nodes;
   // Get the size of t
-  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex11), 0);
-  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex9), 0)[0];
   auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex12), 0);
 
-  for (size_t i = 0; i < t_size; ++i) {
+  for (size_t i = 0; i < specs.t_size; ++i) {
     // Create basic_lstm_cell_c_state_grad
     std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
       NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
     auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);
 
-    std::vector<size_t> output0_dims{
-      origin_input9_shape[kDim0],
-      kDimMultiNum * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
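+    // specs.hidden_nz_size * kCubeSize is hidden_size rounded up to a whole 16-wide
+    // cube, so the 4*hidden gate dimension stays tiled correctly when
+    // hidden_size % 16 != 0.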
+    std::vector<size_t> output0_dims{specs.batch_size, kDimMultiNum * specs.hidden_nz_size * kCubeSize};
     std::vector<size_t> output1_dims{input_i_shape[kDim1], input_i_shape[kDim2]};
     AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
                                         basic_lstm_cell_c_state_grad.get());
@@ -66,30 +95,40 @@
 
     // Create matmul
     auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
-    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
+    std::vector<AnfNodePtr> matmul_inputs;
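+    // When the shapes need alignment the grads are computed on padded Nz data, so
+    // BatchMatMulV2 is used; the 16-aligned case keeps the original MatMul.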
+    if (specs.shape_need_align) {
+      matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMulV2->name())));
+    } else {
+      matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
+    }
     auto matmul = NewCNode(matmul_inputs, func_graph);
     AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
                                         matmul.get());
     AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
     AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);
+    if (specs.shape_need_align) {
+      AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), matmul);
+      AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), matmul);
+    }
 
     // Create split
     std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
     auto split_v = NewCNode(splitv_input, func_graph);
-    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
-    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
-    std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
-    std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
+    std::vector<size_t> split_v_output0_shape{IntToSize(1), specs.batch_size, specs.input_size};
+    std::vector<size_t> split_v_output1_shape{IntToSize(1), specs.batch_size, specs.hidden_size};
     AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
                                         {split_v_output0_shape, split_v_output1_shape}, split_v.get());
 
     AnfAlgo::SetNodeAttr(kAttrSizeSplits,
-                         MakeValue(std::vector<int64_t>{
-                           SizeToLong((origin_output2_shape[kDim2] + kCubeSize - 1) / kCubeSize * kCubeSize),
-                           SizeToLong((origin_output3_shape[kDim1] + kCubeSize - 1) / kCubeSize * kCubeSize)}),
+                         MakeValue(std::vector<int64_t>{SizeToLong(specs.input_nz_size * kCubeSize),
+                                                        SizeToLong(specs.hidden_nz_size * kCubeSize)}),
                          split_v);
     AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(kAttrSplitDimValue)), split_v);
     AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(kAttrNumSplitValue)), split_v);
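+    // With padding in play, SplitV is pinned to FRACTAL_NZ on both sides so the
+    // input-size and hidden-size halves are split on the tiled layout.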
+    if (specs.shape_need_align) {
+      AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ}), split_v);
+      AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ}),
+                           split_v);
+    }
 
     basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
     matmul_nodes.emplace_back(matmul);
@@ -117,7 +156,7 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateLSTMSPlitV(const FuncGraphPtr &func_gr
 void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph,
                                                       const CNodePtr &dynamic_rnn_grad_cnode,
                                                       const std::vector<std::vector<AnfNodePtr>> &result_nodes,
-                                                      size_t num_split_x,
+                                                      size_t num_split_x, bool shape_need_align,
                                                       std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const {
   auto &basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
   auto &matmul_nodes = result_nodes[kIndex1];
@@ -166,7 +205,6 @@
     (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
     (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
     auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);
     MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
     basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
     AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
     // Create outputs for current basic_lstm_cell_c_state_grad node
@@ -176,11 +214,15 @@
     pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;
 
     // Create MatMul
-    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
+    std::vector<AnfNodePtr> matmul_inputs;
+    if (shape_need_align) {
+      matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMulV2->name())));
+    } else {
+      matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
+    }
     (void)matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
     (void)matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
     auto matmul = NewCNode(matmul_inputs, func_graph);
     MS_EXCEPTION_IF_NULL(matmul);
     matmul->set_abstract(matmul_nodes[i]->abstract());
     AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);
 
@@ -188,7 +230,6 @@
     std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                             matmul};
     auto split_v = NewCNode(splitv_input, func_graph);
     MS_EXCEPTION_IF_NULL(split_v);
     split_v->set_abstract(split_nodes[i]->abstract());
     AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);
 
@@ -223,9 +264,10 @@
 
 AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &func_graph,
                                                          const CNodePtr &dynamic_rnn_grad_cnode,
+                                                         const RNNShapeSpecs &specs,
                                                          std::vector<AnfNodePtr> *outputs) const {
   std::vector<std::vector<AnfNodePtr>> result_nodes;
-  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
+  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, specs, &result_nodes);
 
   auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
   std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
@@ -290,7 +332,8 @@
 
   // Add edges
   std::vector<std::vector<AnfNodePtr>> loop_node_outputs;
-  CreateTLoopNodeWithEdge(func_graph, dynamic_rnn_grad_cnode, result_nodes, num_split_x, &loop_node_outputs);
+  CreateTLoopNodeWithEdge(func_graph, dynamic_rnn_grad_cnode, result_nodes, num_split_x, specs.shape_need_align,
+                          &loop_node_outputs);
   auto &pre_basic_lstm_cell_c_state_grad_outputs = loop_node_outputs[kIndex0];
   auto &pre_split_outputs = loop_node_outputs[kIndex1];
   auto &lstm_x_concat_input = loop_node_outputs[kIndex2];
@@ -306,10 +349,8 @@
 
   // Create lstm_gage_concat
   auto lstm_gage_concat = NewCNode(lstm_gage_concat_input, func_graph);
-  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
   AnfAlgo::SetOutputInferTypeAndShape(
-    {kNumberTypeFloat16},
-    {{origin_input7_shape[kDim0], origin_input7_shape[kDim1], kDimMultiNum * origin_input7_shape[kDim2]}},
+    {kNumberTypeFloat16}, {{specs.t_size, specs.batch_size, kDimMultiNum * specs.hidden_nz_size * kCubeSize}},
     lstm_gage_concat.get());
   AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
   AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
@@ -484,37 +525,69 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateBatchMatMul2(const FuncGraphPtr &func_
   return batch_matmul;
 }
 
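+// New helper for the unaligned path: transposes dw from the padded FRAC_NZ result
+// of the ReduceSum/Reshape back into the FRACTAL_ZN_RNN weight layout, carrying the
+// real input/hidden sizes as attrs.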
+CNodePtr DynamicRnnGradFissionV2::CreateTranspose(const FuncGraphPtr &func_graph, const AnfNodePtr &dw_reduce_sum,
+                                                  const RNNShapeSpecs &specs) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimTranspose->name())),
+                                              dw_reduce_sum};
+  auto transpose = NewCNode(transpose_inputs, func_graph);
+  std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size, kDimMultiNum * specs.hidden_size};
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dw_reduce_sum, 0)}, {out_shape},
+                                      transpose.get());
+  AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{1, 0, 2, 3}), transpose);
+  AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), transpose);
+  AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), transpose);
+  AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ}), transpose);
+  AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_FRACTAL_ZN_RNN}), transpose);
+  return transpose;
+}
+
 AnfNodePtr DynamicRnnGradFissionV2::CreateDwReduceSum(const FuncGraphPtr &func_graph,
                                                       const CNodePtr &dynamic_rnn_grad_cnode,
-                                                      const AnfNodePtr &batch_matmul) const {
+                                                      const AnfNodePtr &batch_matmul,
+                                                      const RNNShapeSpecs &specs) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   // Create node
   std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                                batch_matmul};
   auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
   // Set infer data type and shape
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
-                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
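+  // dw is now inferred as (input_size + hidden_size) x (4 * hidden_nz_size * 16),
+  // i.e. with the gate dimension padded, instead of reusing the op's original shape.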
+  std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size,
+                                   kDimMultiNum * specs.hidden_nz_size * kCubeSize};
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)}, {out_shape},
+                                      reduce_sum.get());
   // Set attr
   AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
   AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
   AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
-  return reduce_sum;
+
+  auto ret_node = reduce_sum;
+  if (specs.shape_need_align) {
+    ret_node = CreateTranspose(func_graph, reduce_sum, specs);
+  }
+  return ret_node;
 }
 
 AnfNodePtr DynamicRnnGradFissionV2::CreateDwReshape(const FuncGraphPtr &func_graph,
                                                     const CNodePtr &dynamic_rnn_grad_cnode,
-                                                    const AnfNodePtr &batch_matmul) const {
+                                                    const AnfNodePtr &batch_matmul, const RNNShapeSpecs &specs) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   // Create node
   std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                             batch_matmul};
   auto reshape = NewCNode(reshape_inputs, func_graph);
   // Set infer data type and shape
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
-                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
+  std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size,
+                                   kDimMultiNum * specs.hidden_nz_size * kCubeSize};
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)}, {out_shape},
+                                      reshape.get());
   AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
-  return reshape;
+
+  auto ret_node = reshape;
+  if (specs.shape_need_align) {
+    ret_node = CreateTranspose(func_graph, reshape, specs);
+  }
+  return ret_node;
 }
 
 AnfNodePtr DynamicRnnGradFissionV2::CreateValueNode(const FuncGraphPtr &func_graph,
@@ -537,8 +610,8 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateValueNode(const FuncGraphPtr &func_gra
 }
 
 AnfNodePtr DynamicRnnGradFissionV2::CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &,
-                                                      const AnfNodePtr &lstm_input_grad,
-                                                      const AnfNodePtr &value_node) const {
+                                                      const AnfNodePtr &lstm_input_grad, const AnfNodePtr &value_node,
+                                                      const RNNShapeSpecs &specs) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   // Create node
   auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
@@ -546,12 +619,18 @@
                                                batch_matmul};
   auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
   // Set infer data type and shape
-  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
+  std::vector<size_t> out_shape = {kDimMultiNum * specs.hidden_size};
   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
   // Set attr
   AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
   AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
   AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
+  if (specs.shape_need_align) {
+    AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), reduce_sum);
+    AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), reduce_sum);
+    AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_DEFAULT}), reduce_sum);
+    AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_ND_RNN_BIAS}), reduce_sum);
+  }
   return reduce_sum;
 }
 
@@ -572,20 +651,28 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
     return nullptr;
   }
   if (AnfAlgo::IsDynamicShape(node)) {
-    MS_LOG(INFO) << "DynamicRnnGrad is dynamic shape, can not do fission.";
+    MS_LOG(INFO) << "DynamicRNNGrad is dynamic shape, can not do fission.";
     return nullptr;
   }
-  std::vector<AnfNodePtr> new_outputs;
-  auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
-
-  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[0];
-  size_t hidden_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[kDim2];
-  if (hidden_size % kCubeSize != 0) {
-    MS_LOG(EXCEPTION) << "`hidden_size` in this node should be multiple of 16, but got " << hidden_size << ". "
-                      << dynamic_rnn_grad_cnode->DebugString();
-  }
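+  // Gather all shapes once. The *_nz_size fields are the dims rounded up to whole
+  // 16-element cubes; shape_need_align marks the hidden_size % 16 != 0 case that
+  // this pass previously rejected with an exception.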
+  auto input0_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex1), 0);
+  RNNShapeSpecs specs;
+  specs.t_size = input0_shape[0];
+  specs.batch_size = input0_shape[1];
+  specs.input_size = input0_shape[kDim2];
+  specs.hidden_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[kDim2];
+  if (specs.hidden_size % kCubeSize != 0) {
+    specs.shape_need_align = true;
+    SetAttrInputAndHiddenSize(func_graph, dynamic_rnn_grad_cnode, SizeToLong(specs.input_size),
+                              SizeToLong(specs.hidden_size));
+  }
+  specs.batch_nz_size = (specs.batch_size + kCubeSize - 1) / kCubeSize;
+  specs.input_nz_size = (specs.input_size + kCubeSize - 1) / kCubeSize;
+  specs.hidden_nz_size = (specs.hidden_size + kCubeSize - 1) / kCubeSize;
+
+  std::vector<AnfNodePtr> new_outputs;
+  auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, specs, &new_outputs);
   AnfNodePtr concat = nullptr;
-  if (t_size != 1) {
+  if (specs.t_size != 1) {
     auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
     auto h_concat = CreateHConcat(func_graph, dynamic_rnn_grad_cnode, splitv);
     concat = CreateConcat(func_graph, dynamic_rnn_grad_cnode, h_concat);
@@ -595,17 +682,17 @@
 
   auto batch_matmul = CreateBatchMatMul(func_graph, lstm_input_grad, concat);
   std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
-  if (t_size != 1) {
-    auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
+  if (specs.t_size != 1) {
+    auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul, specs);
     make_tuple_inputs.emplace_back(dw_reduce_sum);
   } else {
-    auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
+    auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul, specs);
     make_tuple_inputs.emplace_back(dw_reshape);
   }
 
   auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
   // create reduce_sum_2
-  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node);
+  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node, specs);
   make_tuple_inputs.emplace_back(db_reduce_sum);
   make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
   auto make_tuple = func_graph->NewCNode(make_tuple_inputs);