|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
#include "backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h" |
|
|
|
#include <vector> |
|
|
|
#include <map> |
|
|
|
#include <memory> |
|
|
|
#include <string> |
|
|
|
#include <algorithm> |
|
|
|
@@ -26,70 +27,211 @@ namespace opt { |
|
|
|
namespace { |
|
|
|
constexpr size_t kDynamicGRUV2GradInputNum = 12; |
|
|
|
constexpr size_t kDynamicGRUV2GradOutputNum = 6; |
|
|
|
constexpr size_t kSplitVOutputNum = 2; |
|
|
|
constexpr size_t kGRUV2HiddenGradOutputNum = 3; |
|
|
|
constexpr size_t kConcatNum = 2; |
|
|
|
constexpr size_t kGRUV2HiddenGradCellOutputNum = 3; |
|
|
|
constexpr size_t kGateNum = 3; |
|
|
|
constexpr size_t k3Dims = 3; |
|
|
|
constexpr size_t kConcatNum = 2; |
|
|
|
constexpr size_t kSplitVOutputNum = 2; |
|
|
|
size_t t_size = 0; |
|
|
|
size_t batch_size = 0; |
|
|
|
size_t hidden_size = 0; |
|
|
|
size_t input_size = 0; |
|
|
|
TypeId dh_dtype = kNumberTypeFloat32; |
|
|
|
|
|
|
|
AnfNodePtr CreateGRUV2HiddenGradNode(const FuncGraphPtr &graph, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
const auto &dynamic_gru_v2_grad_inputs = cnode->inputs(); |
|
|
|
std::vector<AnfNodePtr> gru_v2_hidden_grad_inputs = { |
|
|
|
NewValueNode(std::make_shared<Primitive>(kGRUV2HiddenGradOpName)), |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex3], |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex5], |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex6], |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex7], |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex8], |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex9], |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex10], |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex11], |
|
|
|
dynamic_gru_v2_grad_inputs[kIndex12]}; |
|
|
|
std::map<std::string, size_t> input_index = { |
|
|
|
{"x", kIndex1}, {"weight_input", kIndex2}, {"weight_hidden", kIndex3}, |
|
|
|
{"y", kIndex4}, {"init_h", kIndex5}, {"h", kIndex6}, |
|
|
|
{"dy", kIndex7}, {"dh", kIndex8}, {"update", kIndex9}, |
|
|
|
{"reset", kIndex10}, {"new", kIndex11}, {"hidden_new", kIndex12}, |
|
|
|
{"seq_length", kIndex13}, {"mask", kIndex14}}; |
|
|
|
|
|
|
|
std::map<std::string, size_t> output_index = {{"dw_input", kIndex0}, {"dw_hidden", kIndex1}, {"db_input", kIndex2}, |
|
|
|
{"db_hidden", kIndex3}, {"dx", kIndex4}, {"dh_prev", kIndex5}}; |
|
|
|
|
|
|
|
std::map<std::string, size_t> hidden_grad_input_index = { |
|
|
|
{"dh_pre_t", kIndex1}, {"h", kIndex2}, {"dy", kIndex3}, {"dh", kIndex4}, |
|
|
|
{"update", kIndex5}, {"reset", kIndex6}, {"new", kIndex7}, {"hidden_new", kIndex8}}; |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> ori_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(graph, node, kDynamicGRUV2GradOutputNum, &ori_outputs); |
|
|
|
auto gru_v2_hidden_grad_op = graph->NewCNode(gru_v2_hidden_grad_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(gru_v2_hidden_grad_op); |
|
|
|
auto h_dtype = AnfAlgo::GetOutputInferDataType(dynamic_gru_v2_grad_inputs[kIndex6], 0); |
|
|
|
auto types = {h_dtype, h_dtype, h_dtype}; |
|
|
|
std::vector<size_t> dh_preh_shape = AnfAlgo::GetOutputInferShape(ori_outputs[kIndex5], 0); |
|
|
|
std::vector<size_t> dgate_h_shape = { |
|
|
|
AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim0], |
|
|
|
AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim1], |
|
|
|
kGateNum * AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim2]}; |
|
|
|
std::vector<size_t> dnx_t_shape = AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0); |
|
|
|
auto shapes = {dh_preh_shape, dgate_h_shape, dnx_t_shape}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, gru_v2_hidden_grad_op.get()); |
|
|
|
auto gate_order = AnfAlgo::GetNodeAttr<std::string>(cnode, "gate_order"); |
|
|
|
AnfAlgo::SetNodeAttr("gate_order", MakeValue(gate_order), gru_v2_hidden_grad_op); |
|
|
|
return gru_v2_hidden_grad_op; |
|
|
|
std::map<std::string, size_t> hidden_grad_output_index = { |
|
|
|
{"dh_prev", kIndex0}, {"dgate_h", kIndex1}, {"dnt_x", kIndex2}}; |
|
|
|
|
|
|
|
AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode, |
|
|
|
const AnfNodePtr &last_gru_hidden_grad_node, |
|
|
|
const AnfNodePtr &last_matmul_node, const std::string &gate_order, |
|
|
|
const size_t cur_t) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode); |
|
|
|
const auto &dynamic_gru_v2_grad_inputs = dynamic_gru_v2_grad_cnode->inputs(); |
|
|
|
std::vector<AnfNodePtr> gru_v2_hidden_grad_cell_inputs = { |
|
|
|
NewValueNode(std::make_shared<Primitive>(kGRUV2HiddenGradCellOpName))}; |
|
|
|
std::vector<AnfNodePtr> dynamic_gru_grad_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, dynamic_gru_v2_grad_cnode, kDynamicGRUV2GradOutputNum, |
|
|
|
&dynamic_gru_grad_outputs); |
|
|
|
if (cur_t == 0) { |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["dh"]]); |
|
|
|
} else { |
|
|
|
MS_EXCEPTION_IF_NULL(last_gru_hidden_grad_node); |
|
|
|
std::vector<AnfNodePtr> last_gru_hidden_grad_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, last_gru_hidden_grad_node->cast<CNodePtr>(), |
|
|
|
kGRUV2HiddenGradCellOutputNum, &last_gru_hidden_grad_outputs); |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(last_gru_hidden_grad_outputs[hidden_grad_output_index["dh_prev"]]); |
|
|
|
} |
|
|
|
if (cur_t < t_size - 1) { |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["h"]]); |
|
|
|
} else { |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["init_h"]]); |
|
|
|
} |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["dy"]]); |
|
|
|
auto input_dh = dynamic_gru_v2_grad_inputs[input_index["dh"]]; |
|
|
|
dh_dtype = AnfAlgo::GetOutputInferDataType(input_dh, 0); |
|
|
|
if (cur_t == 0) { |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(input_dh); |
|
|
|
} else { |
|
|
|
MS_EXCEPTION_IF_NULL(last_matmul_node); |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(last_matmul_node); |
|
|
|
} |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["update"]]); |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["reset"]]); |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["new"]]); |
|
|
|
gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["hidden_new"]]); |
|
|
|
auto gru_v2_hidden_grad_cell_op = func_graph->NewCNode(gru_v2_hidden_grad_cell_inputs); |
|
|
|
|
|
|
|
std::vector<size_t> dh_prev_shape = |
|
|
|
AnfAlgo::GetOutputInferShape(dynamic_gru_grad_outputs[output_index["dh_prev"]], 0); |
|
|
|
std::vector<size_t> dgate_h_shape = {1, batch_size, kGateNum * hidden_size}; |
|
|
|
std::vector<size_t> dnt_x_shape = {1, batch_size, hidden_size}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({dh_dtype, dh_dtype, dh_dtype}, {dh_prev_shape, dgate_h_shape, dnt_x_shape}, |
|
|
|
gru_v2_hidden_grad_cell_op.get()); |
|
|
|
AnfAlgo::SetNodeAttr("t_state", MakeValue(SizeToLong(cur_t)), gru_v2_hidden_grad_cell_op); |
|
|
|
AnfAlgo::SetNodeAttr("gate_order", MakeValue(gate_order), gru_v2_hidden_grad_cell_op); |
|
|
|
return gru_v2_hidden_grad_cell_op; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
// SplitV |
|
|
|
std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node}; |
|
|
|
auto split_vd = graph->NewCNode(splitvd_input); |
|
|
|
MS_EXCEPTION_IF_NULL(split_vd); |
|
|
|
auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)}; |
|
|
|
size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim0]; |
|
|
|
size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[kDim1]; |
|
|
|
size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim2]; |
|
|
|
std::vector<size_t> shape = {t_size - IntToSize(1), batch, hidden_size}; |
|
|
|
std::vector<size_t> shape2 = {IntToSize(1), batch, hidden_size}; |
|
|
|
std::vector<std::vector<size_t>> shapes = {shape, shape2}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get()); |
|
|
|
AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(kDim0)), split_vd); |
|
|
|
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(kSplitVOutputNum)), split_vd); |
|
|
|
std::vector<int64_t> size_splits = {SizeToLong(t_size - 1), SizeToLong(1)}; |
|
|
|
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd); |
|
|
|
return split_vd; |
|
|
|
void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode, |
|
|
|
std::vector<std::vector<AnfNodePtr>> *result_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(result_nodes); |
|
|
|
std::string gate_order = "rzh"; |
|
|
|
if (AnfAlgo::HasNodeAttr("gate_order", dynamic_gru_v2_grad_cnode)) { |
|
|
|
gate_order = AnfAlgo::GetNodeAttr<std::string>(dynamic_gru_v2_grad_cnode, "gate_order"); |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> gru_hidden_grad_cells; |
|
|
|
std::vector<AnfNodePtr> matmul_nodes; |
|
|
|
AnfNodePtr last_hidden_grad_node = nullptr; |
|
|
|
AnfNodePtr last_matmul_node = nullptr; |
|
|
|
const auto &dynamic_gru_v2_grad_inputs = dynamic_gru_v2_grad_cnode->inputs(); |
|
|
|
for (size_t i = 0; i < t_size; ++i) { |
|
|
|
// Create gru_hidden_grad_cell |
|
|
|
auto gru_hidden_grad_cell_node = CreateGRUV2HiddenGradCellNode( |
|
|
|
func_graph, dynamic_gru_v2_grad_cnode, last_hidden_grad_node, last_matmul_node, gate_order, i); |
|
|
|
// add matmul node |
|
|
|
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(kBatchMatMulOpName))}; |
|
|
|
auto gru_hidden_grad_cnode = gru_hidden_grad_cell_node->cast<CNodePtr>(); |
|
|
|
std::vector<AnfNodePtr> hidden_grad_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, gru_hidden_grad_cnode, kGRUV2HiddenGradCellOutputNum, |
|
|
|
&hidden_grad_outputs); |
|
|
|
auto dgate_h = hidden_grad_outputs[hidden_grad_output_index["dgate_h"]]; |
|
|
|
matmul_inputs.emplace_back(dgate_h); |
|
|
|
auto weight_hidden = dynamic_gru_v2_grad_inputs[input_index["weight_hidden"]]; |
|
|
|
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), |
|
|
|
weight_hidden}; |
|
|
|
auto reshape = func_graph->NewCNode(reshape_inputs); |
|
|
|
auto reshape_out_shape = {IntToSize(1), AnfAlgo::GetOutputInferShape(weight_hidden, 0)[0], |
|
|
|
AnfAlgo::GetOutputInferShape(weight_hidden, 0)[1]}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {reshape_out_shape}, reshape.get()); |
|
|
|
matmul_inputs.emplace_back(reshape); |
|
|
|
auto matmul_node = func_graph->NewCNode(matmul_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(matmul_node); |
|
|
|
std::vector<size_t> out_shape = {1, batch_size, hidden_size}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {out_shape}, matmul_node.get()); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul_node); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul_node); |
|
|
|
|
|
|
|
last_hidden_grad_node = gru_hidden_grad_cell_node; |
|
|
|
last_matmul_node = matmul_node; |
|
|
|
gru_hidden_grad_cells.emplace_back(gru_hidden_grad_cell_node); |
|
|
|
matmul_nodes.emplace_back(matmul_node); |
|
|
|
} |
|
|
|
// Add last GRUV2HiddenGradCell node |
|
|
|
auto gru_hidden_grad_cell_node = CreateGRUV2HiddenGradCellNode( |
|
|
|
func_graph, dynamic_gru_v2_grad_cnode, last_hidden_grad_node, last_matmul_node, gate_order, t_size); |
|
|
|
gru_hidden_grad_cells.emplace_back(gru_hidden_grad_cell_node); |
|
|
|
result_nodes->emplace_back(gru_hidden_grad_cells); |
|
|
|
result_nodes->emplace_back(matmul_nodes); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &gru_hidden_grad_nodes, |
|
|
|
size_t concat_output_index) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))}; |
|
|
|
for (size_t i = 0; i < t_size; i++) { |
|
|
|
auto gru_hidden_grad_node_i = gru_hidden_grad_nodes[t_size - 1 - i]; |
|
|
|
MS_EXCEPTION_IF_NULL(gru_hidden_grad_node_i); |
|
|
|
std::vector<AnfNodePtr> gru_hidden_grad_node_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, gru_hidden_grad_node_i, kGRUV2HiddenGradCellOutputNum, |
|
|
|
&gru_hidden_grad_node_outputs); |
|
|
|
concat_inputs.emplace_back(gru_hidden_grad_node_outputs[concat_output_index]); |
|
|
|
} |
|
|
|
auto concat_t_node = func_graph->NewCNode(concat_inputs); |
|
|
|
auto out_dims = AnfAlgo::GetOutputInferShape(gru_hidden_grad_nodes[kIndex0], concat_output_index); |
|
|
|
std::vector<size_t> concat_output_shape = {t_size, out_dims[kDim1], out_dims[kDim2]}; |
|
|
|
auto out_type = AnfAlgo::GetOutputInferDataType(gru_hidden_grad_nodes[kIndex0], concat_output_index); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({out_type}, {concat_output_shape}, concat_t_node.get()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(t_size)), concat_t_node); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(t_size)}), concat_t_node); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat_t_node); |
|
|
|
return concat_t_node; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> AddGRUHiddenGradNode(const FuncGraphPtr &func_graph, |
|
|
|
const CNodePtr &dynamic_gru_v2_grad_cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode); |
|
|
|
std::vector<AnfNodePtr> result; |
|
|
|
std::vector<std::vector<AnfNodePtr>> result_nodes; |
|
|
|
// add loop t hidden grad nodes; [[hidden_grad_nodes] [matmul_nodes]] |
|
|
|
AddTLoopNode(func_graph, dynamic_gru_v2_grad_cnode, &result_nodes); |
|
|
|
if (result_nodes.empty() || result_nodes[0].empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "result_node is empty, DynamicGRUGrad fission failed."; |
|
|
|
} |
|
|
|
auto gru_hidden_grad_nodes = result_nodes[kIndex0]; |
|
|
|
result.emplace_back(gru_hidden_grad_nodes[gru_hidden_grad_nodes.size() - 1]); |
|
|
|
if (t_size > 1) { |
|
|
|
// add dnt_x concat node [t_size, batch_size, hidden_size] |
|
|
|
auto dnt_x_concat_t_node = AddTConcatNode(func_graph, gru_hidden_grad_nodes, hidden_grad_output_index["dnt_x"]); |
|
|
|
// add dgate_h concat node [t_size, batch_size, 3 * hidden_size] |
|
|
|
auto dgate_h_concat_t_node = AddTConcatNode(func_graph, gru_hidden_grad_nodes, hidden_grad_output_index["dgate_h"]); |
|
|
|
result.emplace_back(dgate_h_concat_t_node); |
|
|
|
result.emplace_back(dnt_x_concat_t_node); |
|
|
|
} else { |
|
|
|
auto node = result_nodes[kIndex0][kIndex0]; |
|
|
|
result.emplace_back(node); |
|
|
|
result.emplace_back(node); |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr AddHSplitNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode); |
|
|
|
auto input_h = dynamic_gru_v2_grad_cnode->input(input_index["h"]); |
|
|
|
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), |
|
|
|
input_h}; |
|
|
|
auto split_v = func_graph->NewCNode(splitv_input); |
|
|
|
// Set infer data type and shape |
|
|
|
auto dtypes = {AnfAlgo::GetOutputInferDataType(input_h, 0), AnfAlgo::GetOutputInferDataType(input_h, 0)}; |
|
|
|
std::vector<size_t> output1_shape = {t_size - 1, batch_size, hidden_size}; |
|
|
|
std::vector<size_t> output2_shape = {1, batch_size, hidden_size}; |
|
|
|
std::vector<int64_t> split_list = {SizeToLong(t_size - 1), 1}; |
|
|
|
std::vector<std::vector<size_t>> shapes = {output1_shape, output2_shape}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get()); |
|
|
|
// Set attr |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(SizeToLong(0)), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(kSplitVOutputNum)), split_v); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(split_list), split_v); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v); |
|
|
|
return split_v; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) { |
|
|
|
@@ -111,104 +253,110 @@ AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) { |
|
|
|
return reshape; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateHConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node1); |
|
|
|
MS_EXCEPTION_IF_NULL(node2); |
|
|
|
std::vector<AnfNodePtr> ori_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(graph, node2, kSplitVOutputNum, &ori_outputs); |
|
|
|
auto reshape = CreateHReshape(graph, node1); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), |
|
|
|
reshape, ori_outputs[kIndex0]}; |
|
|
|
auto concat_op = graph->NewCNode(concat_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(concat_op); |
|
|
|
|
|
|
|
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[kDim0] + 1, |
|
|
|
AnfAlgo::GetOutputInferShape(node2, 0)[kDim1], |
|
|
|
AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]}; |
|
|
|
auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kConcatNum)), concat_op); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat_op); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op); |
|
|
|
return concat_op; |
|
|
|
AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode, |
|
|
|
const AnfNodePtr &splitv) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(splitv); |
|
|
|
// Create node |
|
|
|
std::vector<AnfNodePtr> splitv_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, splitv, kSplitVOutputNum, &splitv_outputs); |
|
|
|
if (splitv_outputs.size() != kSplitVOutputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "Create outputs of node " << splitv->DebugString() << " failed" |
|
|
|
<< " trace: " << trace::DumpSourceLines(splitv); |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))}; |
|
|
|
auto init_h_reshape = CreateHReshape(func_graph, dynamic_gru_v2_grad_cnode->input(input_index["init_h"])); |
|
|
|
concat_inputs.emplace_back(init_h_reshape); |
|
|
|
concat_inputs.emplace_back(splitv_outputs[kIndex0]); |
|
|
|
auto concat = func_graph->NewCNode(concat_inputs); |
|
|
|
// Set infer data type and shape |
|
|
|
std::vector<size_t> output_shape = {t_size, batch_size, hidden_size}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(init_h_reshape, 0)}, {output_shape}, |
|
|
|
concat.get()); |
|
|
|
// Set attr |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kConcatNum)), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kConcatNum}), concat); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat); |
|
|
|
return concat; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(dgate_h); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
// SplitV |
|
|
|
std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node}; |
|
|
|
auto split_vd = graph->NewCNode(splitvd_input); |
|
|
|
MS_EXCEPTION_IF_NULL(split_vd); |
|
|
|
auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)}; |
|
|
|
size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim0]; |
|
|
|
size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[kDim1]; |
|
|
|
size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim2] / kGateNum; |
|
|
|
std::vector<size_t> shape = {t_size, batch, hidden_size << 1}; |
|
|
|
std::vector<size_t> shape2 = {t_size, batch, hidden_size}; |
|
|
|
// BatchMatMul |
|
|
|
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name()))}; |
|
|
|
matmul_inputs.emplace_back(node); |
|
|
|
if (t_size == 1) { |
|
|
|
std::vector<AnfNodePtr> dgate_h_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, dgate_h, kGRUV2HiddenGradCellOutputNum, &dgate_h_outputs); |
|
|
|
matmul_inputs.emplace_back(dgate_h_outputs[hidden_grad_output_index["dgate_h"]]); |
|
|
|
} else { |
|
|
|
matmul_inputs.emplace_back(dgate_h); |
|
|
|
} |
|
|
|
auto batch_matmul = func_graph->NewCNode(matmul_inputs); |
|
|
|
std::vector<size_t> shape = {t_size, hidden_size, kGateNum * hidden_size}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get()); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul); |
|
|
|
return batch_matmul; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(dgate_h); |
|
|
|
std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))}; |
|
|
|
if (t_size == 1) { |
|
|
|
std::vector<AnfNodePtr> dgate_h_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, dgate_h, kGRUV2HiddenGradCellOutputNum, &dgate_h_outputs); |
|
|
|
splitvd_input.emplace_back(dgate_h_outputs[hidden_grad_output_index["dgate_h"]]); |
|
|
|
} else { |
|
|
|
splitvd_input.emplace_back(dgate_h); |
|
|
|
} |
|
|
|
auto split_vd = func_graph->NewCNode(splitvd_input); |
|
|
|
auto dtypes = {AnfAlgo::GetOutputInferDataType(dgate_h, 0), AnfAlgo::GetOutputInferDataType(dgate_h, 0)}; |
|
|
|
std::vector<size_t> shape = {t_size, batch_size, hidden_size << 1}; |
|
|
|
std::vector<size_t> shape2 = {t_size, batch_size, hidden_size}; |
|
|
|
std::vector<std::vector<size_t>> shapes = {shape, shape2}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get()); |
|
|
|
AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(kDim2)), split_vd); |
|
|
|
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(kSplitVOutputNum)), split_vd); |
|
|
|
std::vector<int64_t> size_splits = {SizeToLong(hidden_size + hidden_size), SizeToLong(hidden_size)}; |
|
|
|
std::vector<int64_t> size_splits = {SizeToLong(hidden_size << 1), SizeToLong(hidden_size)}; |
|
|
|
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd); |
|
|
|
return split_vd; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
// node1: dgate_h_split |
|
|
|
// node2: dnt_x |
|
|
|
MS_EXCEPTION_IF_NULL(node1); |
|
|
|
MS_EXCEPTION_IF_NULL(node2); |
|
|
|
std::vector<AnfNodePtr> ori_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(graph, node1, kSplitVOutputNum, &ori_outputs); |
|
|
|
|
|
|
|
// ConcatD |
|
|
|
AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &split, const AnfNodePtr &dnt_x) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(split); |
|
|
|
MS_EXCEPTION_IF_NULL(dnt_x); |
|
|
|
std::vector<AnfNodePtr> split_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, split, kSplitVOutputNum, &split_outputs); |
|
|
|
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), |
|
|
|
ori_outputs[kIndex0], node2}; |
|
|
|
auto concat_op = graph->NewCNode(concat_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(concat_op); |
|
|
|
std::vector<size_t> shape = { |
|
|
|
AnfAlgo::GetOutputInferShape(node2, 0)[kDim0], AnfAlgo::GetOutputInferShape(node2, 0)[kDim1], |
|
|
|
AnfAlgo::GetOutputInferShape(node1, 0)[kDim2] + AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]}; |
|
|
|
auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)}; |
|
|
|
split_outputs[kIndex0]}; |
|
|
|
if (t_size == 1) { |
|
|
|
std::vector<AnfNodePtr> dnt_x_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, dnt_x, kGRUV2HiddenGradCellOutputNum, &dnt_x_outputs); |
|
|
|
concat_inputs.emplace_back(dnt_x_outputs[hidden_grad_output_index["dnt_x"]]); |
|
|
|
} else { |
|
|
|
concat_inputs.emplace_back(dnt_x); |
|
|
|
} |
|
|
|
auto concat_op = func_graph->NewCNode(concat_inputs); |
|
|
|
std::vector<size_t> shape = {t_size, batch_size, kGateNum * hidden_size}; |
|
|
|
auto types = {AnfAlgo::GetOutputInferDataType(dnt_x, 0)}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kConcatNum)), concat_op); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kConcatNum}), concat_op); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kDim2)), concat_op); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op); |
|
|
|
return concat_op; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
// node1 : input node |
|
|
|
// node2 : orign_input x |
|
|
|
MS_EXCEPTION_IF_NULL(node1); |
|
|
|
MS_EXCEPTION_IF_NULL(node2); |
|
|
|
// BroadcastTo |
|
|
|
std::vector<AnfNodePtr> braodcast_to_input = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)), node1}; |
|
|
|
auto broadcast_to_d = graph->NewCNode(braodcast_to_input); |
|
|
|
MS_EXCEPTION_IF_NULL(broadcast_to_d); |
|
|
|
size_t t_size = AnfAlgo::GetOutputInferShape(node2, 0)[kDim0]; |
|
|
|
size_t batch = AnfAlgo::GetOutputInferShape(node1, 0)[kDim0]; |
|
|
|
size_t gate_size = AnfAlgo::GetOutputInferShape(node1, 0)[kDim1]; |
|
|
|
std::vector<size_t> shape = {t_size, batch, gate_size}; |
|
|
|
auto type = {AnfAlgo::GetOutputInferDataType(node1, 0)}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(type, {shape}, broadcast_to_d.get()); |
|
|
|
|
|
|
|
std::vector<int64_t> attr_shape = {SizeToLong(t_size), SizeToLong(batch), SizeToLong(gate_size)}; |
|
|
|
AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(attr_shape), broadcast_to_d); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), broadcast_to_d); |
|
|
|
return broadcast_to_d; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateDhxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
AnfNodePtr CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node1); |
|
|
|
MS_EXCEPTION_IF_NULL(node2); |
|
|
|
@@ -217,45 +365,57 @@ AnfNodePtr CreateDhxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &nod |
|
|
|
node1, node2}; |
|
|
|
auto batch_matmul = graph->NewCNode(matmul_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(batch_matmul); |
|
|
|
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[kDim0], |
|
|
|
AnfAlgo::GetOutputInferShape(node1, 0)[kDim2], |
|
|
|
AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get()); |
|
|
|
std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {shape}, batch_matmul.get()); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul); |
|
|
|
return batch_matmul; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateDwhBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node1); |
|
|
|
MS_EXCEPTION_IF_NULL(node2); |
|
|
|
// BatchMatMul |
|
|
|
AnfNodePtr CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_concat, |
|
|
|
const AnfNodePtr &weight_input, const AnfNodePtr &dx) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(dgate_concat); |
|
|
|
MS_EXCEPTION_IF_NULL(weight_input); |
|
|
|
MS_EXCEPTION_IF_NULL(dx); |
|
|
|
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())), |
|
|
|
node1, node2}; |
|
|
|
auto batch_matmul = graph->NewCNode(matmul_inputs); |
|
|
|
dgate_concat, weight_input}; |
|
|
|
auto batch_matmul = func_graph->NewCNode(matmul_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(batch_matmul); |
|
|
|
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[kDim0], |
|
|
|
AnfAlgo::GetOutputInferShape(node1, 0)[kDim1], |
|
|
|
AnfAlgo::GetOutputInferShape(node2, 0)[kDim1]}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get()); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dx, 0)}, {AnfAlgo::GetOutputInferShape(dx, 0)}, |
|
|
|
batch_matmul.get()); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), batch_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul); |
|
|
|
return batch_matmul; |
|
|
|
} |
|
|
|
|
|
|
|
// Broadcasts the weight_input node from [input_size, 3 * hidden_size] to
// [t_size, input_size, 3 * hidden_size] with a BroadcastTo node, so it can
// feed a per-timestep BatchMatMul. t_size / input_size / hidden_size are the
// file-scope sizes filled in by Process().
AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  // BroadcastTo
  std::vector<AnfNodePtr> broadcast_to_input = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)), node};
  auto broadcast_to_d = graph->NewCNode(broadcast_to_input);
  MS_EXCEPTION_IF_NULL(broadcast_to_d);
  std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size};
  auto type = {AnfAlgo::GetOutputInferDataType(node, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(type, {shape}, broadcast_to_d.get());
  // The op attribute carries the target shape as signed int64.
  std::vector<int64_t> attr_shape = {SizeToLong(t_size), SizeToLong(input_size), SizeToLong(kGateNum * hidden_size)};
  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(attr_shape), broadcast_to_d);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), broadcast_to_d);
  return broadcast_to_d;
}
|
|
|
|
|
|
|
// Creates a ReduceSumD node that sums `matmul` over axis 0 (the time axis),
// producing the dw_input / dw_hidden result. The inferred dtype and shape are
// copied from the corresponding DynamicGRUV2Grad output `gru_grad` so the
// fission keeps the original contract.
AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &matmul, const AnfNodePtr &gru_grad) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(matmul);
  MS_EXCEPTION_IF_NULL(gru_grad);
  // ReduceSumD for dw_x and dw_h
  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                              matmul};
  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
  MS_EXCEPTION_IF_NULL(reduce_sumd);
  auto types = {AnfAlgo::GetOutputInferDataType(gru_grad, 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(gru_grad, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
  // Reduce over t; the summed axis is dropped (keep_dims = false).
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sumd);
  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
  return reduce_sumd;
}
|
|
|
@@ -272,9 +432,8 @@ AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &n |
|
|
|
node}; |
|
|
|
auto reduce_sumd = graph->NewCNode(reducesum_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(reduce_sumd); |
|
|
|
|
|
|
|
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; |
|
|
|
std::vector<size_t> shape = {kGateNum * AnfAlgo::GetOutputInferShape(node2, 0)[kDim1]}; |
|
|
|
std::vector<size_t> shape = {kGateNum * hidden_size}; |
|
|
|
auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, reduce_sumd.get()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sumd); |
|
|
|
AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd); |
|
|
|
@@ -299,52 +458,76 @@ const AnfNodePtr DynamicGRUV2GradFission::Process(const FuncGraphPtr &func_graph |
|
|
|
<< kDynamicGRUV2GradInputNum << " inputs"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (AnfAlgo::IsDynamicShape(node)) { |
|
|
|
MS_LOG(INFO) << "DynamicGRUV2Grad is dynamic shape, can not optimizer."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// input_list of dynamic_gru_v2_grad |
|
|
|
const auto &ori_inputs = dynamic_gru_v2_grad_cnode->inputs(); |
|
|
|
// add gru_v2_gru_hidden |
|
|
|
auto gru_v2_gru_hidden = CreateGRUV2HiddenGradNode(func_graph, dynamic_gru_v2_grad_cnode); |
|
|
|
std::vector<AnfNodePtr> gru_hidden_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, gru_v2_gru_hidden, kGRUV2HiddenGradOutputNum, &gru_hidden_outputs); |
|
|
|
size_t step_num = AnfAlgo::GetOutputInferShape(ori_inputs[kIndex1], 0)[kDim0]; |
|
|
|
AnfNodePtr dwh_batch_matmul = nullptr; |
|
|
|
if (step_num != 1) { |
|
|
|
std::vector<AnfNodePtr> gru_grad_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, dynamic_gru_v2_grad_cnode, kDynamicGRUV2GradOutputNum, &gru_grad_outputs); |
|
|
|
auto input_h = ori_inputs[input_index["h"]]; |
|
|
|
auto input_x = ori_inputs[input_index["x"]]; |
|
|
|
t_size = AnfAlgo::GetOutputInferShape(input_h, 0)[kDim0]; |
|
|
|
batch_size = AnfAlgo::GetOutputInferShape(input_h, 0)[kDim1]; |
|
|
|
hidden_size = AnfAlgo::GetOutputInferShape(input_h, 0)[kDim2]; |
|
|
|
input_size = AnfAlgo::GetOutputInferShape(input_x, 0)[kDim2]; |
|
|
|
MS_LOG(INFO) << "For DynamicGRUV2Grad op, t_size: " << t_size << ", batch_size: " << batch_size |
|
|
|
<< ", hidden_size: " << hidden_size << ", input_size: " << input_size; |
|
|
|
// add GRUHiddenGrad {dhPrevNode, dgateHConcatTNode, dntXConcatTNode} |
|
|
|
std::vector<AnfNodePtr> gru_hidden_grad_nodes = AddGRUHiddenGradNode(func_graph, dynamic_gru_v2_grad_cnode); |
|
|
|
AnfNodePtr dwh_matmul_node; |
|
|
|
auto dgate_h = gru_hidden_grad_nodes[hidden_grad_output_index["dgate_h"]]; |
|
|
|
if (t_size != 1) { |
|
|
|
// split h |
|
|
|
auto h_split = CreateHSplitVDNode(func_graph, ori_inputs[kIndex6]); |
|
|
|
auto split = AddHSplitNode(func_graph, dynamic_gru_v2_grad_cnode); |
|
|
|
// concat(h, h_split) |
|
|
|
auto h_concat = CreateHConcatDNode(func_graph, ori_inputs[kIndex5], h_split); |
|
|
|
// batchmatmul(h_concat.T, dgate_h) |
|
|
|
dwh_batch_matmul = CreateDhxBatchMatMul(func_graph, h_concat, gru_hidden_outputs[kIndex1]); |
|
|
|
auto h_concat = AddHConcatNode(func_graph, dynamic_gru_v2_grad_cnode, split); |
|
|
|
// add matmul(h_prev.T, dgate_h) |
|
|
|
dwh_matmul_node = AddDwhMatmulNode(func_graph, dgate_h, h_concat); |
|
|
|
} else { |
|
|
|
auto reshape = CreateHReshape(func_graph, ori_inputs[kIndex5]); |
|
|
|
// batchmatmul(init_h.T, dgate_h) |
|
|
|
dwh_batch_matmul = CreateDhxBatchMatMul(func_graph, reshape, gru_hidden_outputs[kIndex1]); |
|
|
|
auto reshape = CreateHReshape(func_graph, ori_inputs[input_index["init_h"]]); |
|
|
|
dwh_matmul_node = AddDwhMatmulNode(func_graph, dgate_h, reshape); |
|
|
|
} |
|
|
|
// split dgate_h |
|
|
|
auto dgate_h_split = CreateDgateHSplitVDNode(func_graph, gru_hidden_outputs[kIndex1]); |
|
|
|
// split dgate_h to [dit, drt] and [dnt_h] |
|
|
|
auto dgate_h_split = CreateDgateHSplitVDNode(func_graph, dgate_h); |
|
|
|
// concat(dgate_h_split[0], dnt_x) to dgate_x |
|
|
|
auto dgate_x_concat = CreateDgateXConcatDNode(func_graph, dgate_h_split, gru_hidden_outputs[kIndex2]); |
|
|
|
auto dgate_x_concat = |
|
|
|
CreateDgateXConcatDNode(func_graph, dgate_h_split, gru_hidden_grad_nodes[hidden_grad_output_index["dnt_x"]]); |
|
|
|
// broadcast weight_input [input_size, 3 * hidden_size] to [t_size, input_size, 3 * hidden_size] |
|
|
|
auto w_input_broadcast = CreateWBroadcastToDNode(func_graph, ori_inputs[kIndex2], ori_inputs[kIndex1]); |
|
|
|
// batchmatmul(x.T, dgate_x_concat) |
|
|
|
auto dwx_batch_matmul = CreateDhxBatchMatMul(func_graph, ori_inputs[kIndex1], dgate_x_concat); |
|
|
|
auto w_input_broadcast = CreateWBroadcastToDNode(func_graph, ori_inputs[input_index["weight_input"]]); |
|
|
|
// batchmatmul(dgate_x_concat, w_input_broadcast.T) |
|
|
|
auto dxt_batch_matmul = CreateDwhBatchMatMul(func_graph, dgate_x_concat, w_input_broadcast); |
|
|
|
auto dxt_batch_matmul = |
|
|
|
CreateDxtBatchMatMul(func_graph, dgate_x_concat, w_input_broadcast, gru_grad_outputs[output_index["dx"]]); |
|
|
|
// batchmatmul(x.T, dgate_x_concat) |
|
|
|
auto dwx_batch_matmul = CreateDwxBatchMatMul(func_graph, ori_inputs[input_index["x"]], dgate_x_concat); |
|
|
|
// reducesum dw_x and dw_h |
|
|
|
auto dwx_reduce_sum = CreateDwReduceSumDNode(func_graph, dwx_batch_matmul, ori_inputs[kIndex2]); |
|
|
|
auto dwh_reduce_sum = CreateDwReduceSumDNode(func_graph, dwh_batch_matmul, ori_inputs[kIndex3]); |
|
|
|
auto dwx_reduce_sum = |
|
|
|
CreateDwReduceSumDNode(func_graph, dwx_batch_matmul, gru_grad_outputs[output_index["dw_input"]]); |
|
|
|
auto dwh_reduce_sum = |
|
|
|
CreateDwReduceSumDNode(func_graph, dwh_matmul_node, gru_grad_outputs[output_index["dw_hidden"]]); |
|
|
|
// reducesum db_x and db_h |
|
|
|
auto dbx_reduce_sum = CreateDbReduceSumDNode(func_graph, dgate_x_concat, ori_inputs[kIndex5]); |
|
|
|
auto dbh_reduce_sum = CreateDbReduceSumDNode(func_graph, gru_hidden_outputs[kIndex1], ori_inputs[kIndex5]); |
|
|
|
AnfNodePtr dbh_reduce_sum; |
|
|
|
if (t_size == 1) { |
|
|
|
std::vector<AnfNodePtr> dbh_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, dgate_h, kGRUV2HiddenGradCellOutputNum, &dbh_outputs); |
|
|
|
dbh_reduce_sum = CreateDbReduceSumDNode(func_graph, dbh_outputs[kIndex1], ori_inputs[kIndex5]); |
|
|
|
} else { |
|
|
|
dbh_reduce_sum = CreateDbReduceSumDNode(func_graph, dgate_h, ori_inputs[kIndex5]); |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> dh_prev_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, gru_hidden_grad_nodes[kIndex0], kGRUV2HiddenGradCellOutputNum, |
|
|
|
&dh_prev_outputs); |
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), |
|
|
|
dwx_reduce_sum, |
|
|
|
dwh_reduce_sum, |
|
|
|
dbx_reduce_sum, |
|
|
|
dbh_reduce_sum, |
|
|
|
dxt_batch_matmul, |
|
|
|
gru_hidden_outputs[kIndex0]}; |
|
|
|
dh_prev_outputs[kIndex0]}; |
|
|
|
auto make_tuple = func_graph->NewCNode(make_tuple_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
return make_tuple; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
|