@@ -16,8 +16,10 @@
 #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
 #include <vector>
 #include <memory>
+#include "backend/session/kernel_graph.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/trace_base.h"
+#include "utils/tensor_construct_utils.h"
 
 namespace mindspore {
 namespace opt {
@@ -46,7 +48,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
     std::vector<size_t> output0_dims{origin_input9_shape[0], 4 * (((origin_input9_shape[1] + 15) / 16) * 16)};
     std::vector<size_t> output1_dims{input_i_shape[1], input_i_shape[2]};
-    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {output0_dims, output1_dims},
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
                                         basic_lstm_cell_c_state_grad.get());
     AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
     AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);
@@ -260,7 +262,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
   // Create lstm_gage_concat
   auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
   auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
-  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16},
                                       {{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
                                       lstm_gage_concat.get());
   AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
@@ -413,6 +415,24 @@ AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &l
   return batch_matmul;
 }
 
+AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
+                              const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
+                                           node, lstm_input_grad};
+  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
+  // Set infer data type and shape
+  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[0], IntToSize(1),
+                    AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, batch_matmul.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
+  return batch_matmul;
+}
+
 AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                              const AnfNodePtr &batch_matmul) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -430,18 +450,38 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
   return reduce_sum;
 }
 
+AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
+  auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
+  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
+  auto t_size = origin_input7_shape[0];
+  auto n_size = origin_input7_shape[1];
+
+  std::vector<size_t> shape = {t_size, IntToSize(1), n_size};
+  std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)};
+  std::vector<int64_t> output_tensor = {(SizeToLong(n_size) + SizeToLong(15)) / SizeToLong(16) * SizeToLong(16) *
+                                        SizeToLong(16) * SizeToLong(t_size)};
+  auto tensor = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, output_tensor);
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
+  kernel_graph->AddValueNodeToGraph(value_node);
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, value_node.get());
+  return value_node;
+}
+
 AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
-                             const AnfNodePtr &lstm_input_grad) {
+                             const AnfNodePtr &lstm_input_grad, const AnfNodePtr &value_node) {
   MS_EXCEPTION_IF_NULL(func_graph);
   // Create node
+  auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
   std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
-                                               lstm_input_grad};
+                                               batch_matmul};
   auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
   // Set infer data type and shape
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 1)},
-                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 1)}, reduce_sum.get());
+  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
   // Set attr
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sum);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
   AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
   AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
   return reduce_sum;
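
The hunk above is the heart of this patch. Instead of a two-axis ReduceSum over lstm_input_grad, CreateDbReduceSum now multiplies the ones tensor from CreateValueNode, shape (t, 1, n), into lstm_input_grad via CreateBatchMatMul2, which sums out the batch axis, and a single ReduceSum over axis 0 then folds the time axis. A minimal standalone sketch of why this is equivalent, in plain C++ rather than MindSpore APIs, with made-up sizes t, n, m standing in for T, batch and 4 * hidden_size:

// Sketch only: demonstrates ones(t, 1, n) x grad(t, n, m) -> (t, 1, m),
// followed by a sum over t, matching CreateBatchMatMul2 + ReduceSum(axis {0}).
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t t = 2, n = 3, m = 4;  // hypothetical T, batch, 4 * hidden
  std::vector<double> grad(t * n * m);    // stand-in for lstm_input_grad
  for (std::size_t i = 0; i < grad.size(); ++i) grad[i] = static_cast<double>(i);

  std::vector<double> db(m, 0.0);
  for (std::size_t i = 0; i < t; ++i)    // ReduceSum over the time axis
    for (std::size_t j = 0; j < n; ++j)  // ones-row matmul: sum over batch
      for (std::size_t k = 0; k < m; ++k) db[k] += 1.0 * grad[(i * n + j) * m + k];

  // Prints the same values a direct ReduceSum over axes {0, 1} would produce.
  for (double x : db) std::cout << x << ' ';
  std::cout << '\n';
  return 0;
}

Routing the batch reduction through a matmul presumably lets it run as a float16 cube-unit operation on Ascend alongside the pass's other matmuls; the patch itself does not state the motivation, so treat that as an assumption.
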
@@ -486,8 +526,9 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
     make_tuple_inputs.emplace_back(batch_matmul);
   }
 
+  auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
   // create reduce_sum_2
-  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad);
+  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node);
   make_tuple_inputs.emplace_back(db_reduce_sum);
   make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
   auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
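
A final arithmetic note: CreateValueNode allocates the ones tensor with (n + 15) / 16 * 16 * 16 * t elements, that is, n rounded up to a multiple of 16, times a 16-wide tile, times the time steps, while declaring its inferred shape as (t, 1, n). The same round-up appears in output0_dims in CreateTLoopNode. The factor of 16 presumably matches the 16 x 16 blocks of Ascend fractal layouts; that rationale is an assumption, not stated in the patch. A tiny sketch of the idiom, with RoundUp16 as a hypothetical helper name:

// RoundUp16 is hypothetical; (v + 15) / 16 * 16 is the integer round-up idiom
// used in output0_dims and CreateValueNode above.
#include <cstdint>
#include <iostream>

int64_t RoundUp16(int64_t v) { return (v + 15) / 16 * 16; }

int main() {
  const int64_t t = 2, n = 17;
  std::cout << 4 * RoundUp16(n) << '\n';       // output0_dims axis 1: 4 * 32 = 128
  std::cout << RoundUp16(n) * 16 * t << '\n';  // ones-tensor element count: 1024
  return 0;
}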