From 662976a75d4ec4df02438d9b231fc3642790b86c Mon Sep 17 00:00:00 2001
From: liubuyu
Date: Sat, 31 Oct 2020 14:22:11 +0800
Subject: [PATCH] dynamic rnn fission pass v2

---
 .../ascend/ascend_backend_optimization.cc     |   6 +-
 .../format_type/dynamic_rnn_grad_reformat.cc  |  80 +++
 .../format_type/dynamic_rnn_grad_reformat.h   |  41 ++
 .../ir_fission/dynamic_rnn_grad_fission.cc    | 250 ---------
 .../ir_fission/dynamic_rnn_grad_fission_v2.cc | 483 ++++++++++++++++++
 ...ission.h => dynamic_rnn_grad_fission_v2.h} |  15 +-
 mindspore/ccsrc/utils/utils.h                 |   3 +
 mindspore/ops/_op_impl/tbe/__init__.py        |   1 +
 .../tbe/basic_lstm_cell_c_state_grad_v2.py    |  51 ++
 9 files changed, 671 insertions(+), 259 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h
 delete mode 100644 mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
 rename mindspore/ccsrc/backend/optimizer/ascend/ir_fission/{dynamic_rnn_grad_fission.h => dynamic_rnn_grad_fission_v2.h} (77%)
 create mode 100644 mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 88934d6844..7924d06888 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -19,7 +19,7 @@
 #include <memory>
 #include <string>
 #include "backend/optimizer/common/optimizer.h"
-#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
+#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
 #include "backend/optimizer/ascend/ir_fission/bn_split.h"
 #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
 #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
@@ -61,6 +61,7 @@
 #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
 #include "backend/optimizer/ascend/format_type/insert_trans_op.h"
+#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
 #include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
 #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
 #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
@@ -215,6 +216,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
   data_layout_pm->AddPass(std::make_shared());
+  data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
   data_layout_pm->AddPass(std::make_shared());
   data_layout_pm->AddPass(std::make_shared());
   data_layout_pm->AddPass(std::make_shared());
@@ -276,7 +278,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared<DynamicRNNGradFission>());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
   AddAscendIRFusionRulesPass(ir_fusion_pm.get());
   AddAscendIRFusionPass(ir_fusion_pm.get());
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.cc
new file mode 100644
index 0000000000..7c3dfa1f06
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.cc
@@ -0,0 +1,80 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
+#include <memory>
+#include "backend/optimizer/ascend/ascend_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "base/core_ops.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef DynamicRNNGradReformat::DefinePattern() const {
+  VarPtr Xs = std::make_shared<Var>();
+  VarPtr Xs2 = std::make_shared<Var>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Xs2);
+  const auto split = std::make_shared<Primitive>(prim::kPrimSplitV->name());
+  return VectorRef({split, VectorRef({std::make_shared<Primitive>(prim::kPrimMatMul->name()), Xs, Xs2})});
+}
+
+const AnfNodePtr DynamicRNNGradReformat::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                 const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto split_v = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(split_v);
+  auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), 3);
+  MS_EXCEPTION_IF_NULL(matmul);
+  auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(matmul, 0);
+  auto input_node = input_node_with_idx.first;
+  MS_EXCEPTION_IF_NULL(input_node);
+  if (!(input_node->isa<CNode>() &&
+        AnfAlgo::GetCNodeName(input_node->cast<CNodePtr>()) == kBasicLSTMCellCStateGradV2OpName)) {
+    return nullptr;
+  }
+
+  // reformat matmul
+  auto matmul_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(matmul);
+  MS_EXCEPTION_IF_NULL(matmul_kernel_build_info);
+  auto matmul_new_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  matmul_new_builder->SetInputsFormat({kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ});
+  matmul_new_builder->SetOutputsFormat({kOpFormat_FRAC_NZ});
+  matmul_new_builder->SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
+  matmul_new_builder->SetOutputsDeviceType({kNumberTypeFloat});
+  matmul_new_builder->SetKernelType(matmul_kernel_build_info->kernel_type());
+  matmul_new_builder->SetFusionType(matmul_kernel_build_info->fusion_type());
+  matmul_new_builder->SetProcessor(matmul_kernel_build_info->processor());
+  AnfAlgo::SetSelectKernelBuildInfo(matmul_new_builder->Build(), matmul.get());
+  AnfAlgo::SetNodeAttr("insert_backend", MakeValue(true), matmul);
+
+  // reformat split_v
+  auto split_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(split_v);
+  MS_EXCEPTION_IF_NULL(split_kernel_build_info);
+  auto split_new_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  split_new_builder->SetInputsFormat({kOpFormat_FRAC_NZ});
+  split_new_builder->SetOutputsFormat({kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ});
+  split_new_builder->SetInputsDeviceType(split_kernel_build_info->GetAllInputDeviceTypes());
+  split_new_builder->SetOutputsDeviceType(split_kernel_build_info->GetAllOutputDeviceTypes());
+  split_new_builder->SetKernelType(split_kernel_build_info->kernel_type());
+  split_new_builder->SetFusionType(split_kernel_build_info->fusion_type());
+  split_new_builder->SetProcessor(split_kernel_build_info->processor());
+  AnfAlgo::SetSelectKernelBuildInfo(split_new_builder->Build(), split_v.get());
+  AnfAlgo::SetNodeAttr("insert_backend", MakeValue(true), split_v);
+  return split_v;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h
new file mode 100644
index 0000000000..bdd00bacc3
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
+#include
+#include
+#include
+#include
+#include "ir/anf.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class DynamicRNNGradReformat : public PatternProcessPass {
+ public:
+  explicit DynamicRNNGradReformat(bool multigraph = true)
+      : PatternProcessPass("dynamic_rnn_grad_reformat", multigraph) {}
+  ~DynamicRNNGradReformat() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
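Note: the reformat pass above pins both the MatMul and the trailing SplitV to kOpFormat_FRAC_NZ. For intuition only, a small Python sketch of the FRACTAL_NZ tiling this format implies; the helper name nd_to_frac_nz is illustrative and not part of the patch, and the exact dimension ordering is an assumption to check against the CANN documentation:

    # Rough sketch: FRACTAL_NZ stores a (..., A, B) tensor as 16x16 tiles,
    # padding both trailing dimensions up to multiples of 16. Illustrative only.
    def nd_to_frac_nz(shape):
        """Map an ND shape to its (assumed) FRACTAL_NZ device shape."""
        *batch, a, b = shape
        return (*batch, (b + 15) // 16, (a + 15) // 16, 16, 16)

    # A 100x96 matrix becomes 6x7 tiles of 16x16, i.e. (6, 7, 16, 16).
    print(nd_to_frac_nz((100, 96)))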
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.cc
deleted file mode 100644
index 2b15a04c0d..0000000000
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.cc
+++ /dev/null
@@ -1,250 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
-#include <vector>
-#include <memory>
-#include <string>
-#include "backend/session/anf_runtime_algorithm.h"
-#include "backend/optimizer/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-constexpr size_t kDynamicRNNGradInputNum = 16;
-constexpr size_t kLSTMInputGradOutputNum = 4;
-const BaseRef DynamicRNNGradFission::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
-}
-
-AnfNodePtr CreateSplitVD(const FuncGraphPtr &graph, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  // SplitV
-  std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
-  auto split_vd = graph->NewCNode(splitvd_input);
-  MS_EXCEPTION_IF_NULL(split_vd);
-  auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
-  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node, 0)[0] - 1, AnfAlgo::GetOutputInferShape(node, 0)[1],
-                               AnfAlgo::GetOutputInferShape(node, 0)[2]};
-  auto shape2 = {IntToSize(1), AnfAlgo::GetOutputInferShape(node, 0)[1], AnfAlgo::GetOutputInferShape(node, 0)[2]};
-  std::vector<std::vector<size_t>> shapes = {shape, shape2};
-  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
-  AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split_vd);
-  AnfAlgo::SetNodeAttr("num_split", MakeValue(2), split_vd);
-  int tmp = SizeToInt(AnfAlgo::GetOutputInferShape(node, 0)[0]) - 1;
-  AnfAlgo::SetNodeAttr("size_splits", MakeValue(std::vector<int>{tmp, 1}), split_vd);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
-  return split_vd;
-}
-
-AnfNodePtr CreateLSTMInputGrad(const FuncGraphPtr &graph, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  const auto &dynamic_rnn_grad_inputs = cnode->inputs();
-  std::vector<AnfNodePtr> lstm_input_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kLSTMInputGradOpName)),
-                                                    dynamic_rnn_grad_inputs[2],
-                                                    dynamic_rnn_grad_inputs[6],
-                                                    dynamic_rnn_grad_inputs[8],
-                                                    dynamic_rnn_grad_inputs[9],
-                                                    dynamic_rnn_grad_inputs[10],
-                                                    dynamic_rnn_grad_inputs[11],
-                                                    dynamic_rnn_grad_inputs[12],
-                                                    dynamic_rnn_grad_inputs[13],
-                                                    dynamic_rnn_grad_inputs[14],
-                                                    dynamic_rnn_grad_inputs[15],
-                                                    dynamic_rnn_grad_inputs[16]};
-  std::vector<AnfNodePtr> ori_outputs;
-  CreateMultipleOutputsOfAnfNode(graph, node, 5, &ori_outputs);
-  auto lstm_op = graph->NewCNode(lstm_input_grad_inputs);
-  MS_EXCEPTION_IF_NULL(lstm_op);
-  auto ori_type = AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_inputs[8], 0);
-  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[2], 0), AnfAlgo::GetOutputInferDataType(ori_outputs[3], 0),
-                AnfAlgo::GetOutputInferDataType(ori_outputs[4], 0), ori_type};
-  std::vector<size_t> ori_shape = {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[0],
-                                   AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[1],
-                                   4 * AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[2]};
-  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[2], 0), AnfAlgo::GetOutputInferShape(ori_outputs[3], 0),
-                 AnfAlgo::GetOutputInferShape(ori_outputs[4], 0), ori_shape};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lstm_op.get());
-  return lstm_op;
-}
-
-AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  // BatchMatMul
-  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
-                                           node2, node1};
-  auto batch_matmul = graph->NewCNode(matmul_inputs);
-  MS_EXCEPTION_IF_NULL(batch_matmul);
-  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
-  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0], AnfAlgo::GetOutputInferShape(node2, 0)[2],
-                               AnfAlgo::GetOutputInferShape(node1, 0)[2]};
-  auto shapes = {shape};
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
-  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
-  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, batch_matmul.get());
-  return batch_matmul;
-}
-
-AnfNodePtr AddHConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  std::vector<AnfNodePtr> ori_outputs;
-  CreateMultipleOutputsOfAnfNode(graph, node2, 2, &ori_outputs);
-  auto ori_shape = AnfAlgo::GetOutputInferShape(node1, 0);
-  std::vector<std::vector<size_t>> shape_tmp;
-  if (ori_shape.size() == 3) {
-    shape_tmp = {ori_shape};
-  } else {
-    shape_tmp = {{IntToSize(1), ori_shape[0], ori_shape[1]}};
-  }
-  auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node1, 0)};
-  // reshape
-  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
-                                           node1};
-  auto reshape = graph->NewCNode(reshape_input);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
-  AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());
-
-  // concatd --> concat
-  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
-                                           reshape, ori_outputs[0]};
-  auto concat_op = graph->NewCNode(concat_inputs);
-  MS_EXCEPTION_IF_NULL(concat_op);
-  std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node2, 0)[0] + 1, AnfAlgo::GetOutputInferShape(node2, 0)[1],
-                               AnfAlgo::GetOutputInferShape(node2, 0)[2]};
-  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
-  auto shapes = {input};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
-  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
-  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
-  AnfAlgo::SetNodeAttr("axis", MakeValue(0), concat_op);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
-  return concat_op;
-}
-
-AnfNodePtr AddConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  // concatd --> concat
-  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), node1,
-                                           node2};
-  auto concat_op = graph->NewCNode(concat_inputs);
-  MS_EXCEPTION_IF_NULL(concat_op);
-  std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[1],
-                               AnfAlgo::GetOutputInferShape(node1, 0)[2] + AnfAlgo::GetOutputInferShape(node2, 0)[2]};
-  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
-  auto shapes = {input};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
-  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
-  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
-  AnfAlgo::SetNodeAttr("axis", MakeValue(2), concat_op);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
-  return concat_op;
-}
-
-AnfNodePtr AddDwReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  // node1 : dynamic output
-  // node2 : matmul
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  std::vector<AnfNodePtr> ori_outputs;
-  CreateMultipleOutputsOfAnfNode(graph, node1, 5, &ori_outputs);
-  // ReduceSumd
-  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
-                                              node2};
-  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
-  MS_EXCEPTION_IF_NULL(reduce_sumd);
-  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[0], 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[0], 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sumd);
-  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
-  return reduce_sumd;
-}
-
-AnfNodePtr AddDbReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  // node1 : lstm output
-  // node2 : dynamic output
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  std::vector<AnfNodePtr> ori_outputs;
-  CreateMultipleOutputsOfAnfNode(graph, node2, 5, &ori_outputs);
-  // ReduceSumd --> ReduceSum
-  std::vector<AnfNodePtr> reducerum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
-                                              node1};
-  auto reduce_sumd = graph->NewCNode(reducerum_inputs);
-  MS_EXCEPTION_IF_NULL(reduce_sumd);
-  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[1], 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[1], 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sumd);
-  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
-  return reduce_sumd;
-}
-
-const AnfNodePtr DynamicRNNGradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                                const EquivPtr &) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->size() < kDynamicRNNGradInputNum + 1) {
-    MS_LOG(INFO) << "The input num of DynamicRNNGrad is less than " << kDynamicRNNGradInputNum
-                 << ". The node should not be changed";
-    return nullptr;
-  }
-  // input_list of dynamic_rnn_grad
-  const auto &ori_inputs = cnode->inputs();
-  // create split_vd
-  auto split_vd = CreateSplitVD(func_graph, ori_inputs[7]);
-  // create concat_1
-  auto h_concat = AddHConcatD(func_graph, ori_inputs[5], split_vd);
-  // create concat_2
-  auto concat = AddConcatD(func_graph, ori_inputs[1], h_concat);
-  // create lstm_input_grad
-  auto lstm_input_grad = CreateLSTMInputGrad(func_graph, cnode);
-  std::vector<AnfNodePtr> lstm_outputs;
-  CreateMultipleOutputsOfAnfNode(func_graph, lstm_input_grad, kLSTMInputGradOutputNum, &lstm_outputs);
-  // create matmul
-  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_outputs[3], concat);
-  // create reduce_sum_1
-  auto dw_reduce_sum = AddDwReduceSum(func_graph, node, batch_matmul);
-  // create reduce_sum_2
-  auto db_reduce_sum = AddDbReduceSum(func_graph, lstm_outputs[3], node);
-  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple),
-                                               dw_reduce_sum,
-                                               db_reduce_sum,
-                                               lstm_outputs[0],
-                                               lstm_outputs[1],
-                                               lstm_outputs[2]};
-  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
-  MS_EXCEPTION_IF_NULL(make_tuple);
-  return make_tuple;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
new file mode 100644
index 0000000000..a3759adb4b
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
@@ -0,0 +1,483 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
+#include <vector>
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kDynamicRNNGradInputNum = 16;
+constexpr size_t kSplitVOutputNum = 2;
+constexpr size_t kLSTMInputGradOutputNum = 4;
+
+void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                     std::vector<std::vector<AnfNodePtr>> *result_nodes) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  MS_EXCEPTION_IF_NULL(result_nodes);
+  std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_nodes;
+  std::vector<AnfNodePtr> matmul_nodes;
+  std::vector<AnfNodePtr> split_nodes;
+  // Get the size of t
+  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(10), 0);
+  size_t t_size = origin_input9_shape[0];
+  auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(12), 0);
+
+  for (size_t i = 0; i < t_size; ++i) {
+    // Create basic_lstm_cell_c_state_grad
+    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
+    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
+
+    std::vector<size_t> output0_dims{origin_input9_shape[0], 4 * (((origin_input9_shape[1] + 15) / 16) * 16)};
+    std::vector<size_t> output1_dims{input_i_shape[1], input_i_shape[2]};
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {output0_dims, output1_dims},
+                                        basic_lstm_cell_c_state_grad.get());
+    AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
+    AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);
+
+    // Create matmul
+    auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(2), 0);
+    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
+    auto matmul = func_graph->NewCNode(matmul_inputs);
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{output0_dims[0], origin_input1_shape[0]}},
+                                        matmul.get());
+    AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
+    AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);
+
+    // Create split
+    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
+    auto split_v = func_graph->NewCNode(splitv_input);
+    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
+    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
+    std::vector<size_t> split_v_output0_shape{origin_output2_shape[1], origin_output2_shape[2]};
+    std::vector<size_t> split_v_output1_shape{origin_output3_shape[0], origin_output3_shape[1]};
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
+                                        {split_v_output0_shape, split_v_output1_shape}, split_v.get());
+
+    AnfAlgo::SetNodeAttr(kAttrSizeSplits,
+                         MakeValue(std::vector<int>{SizeToInt((origin_output2_shape[2] + 15) / 16),
+                                                    SizeToInt((origin_output3_shape[1] + 15) / 16)}),
+                         split_v);
+    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
+    AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
+
+    basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
+    matmul_nodes.emplace_back(matmul);
+    split_nodes.emplace_back(split_v);
+  }
+  result_nodes->emplace_back(basic_lstm_cell_c_state_grad_nodes);
+  result_nodes->emplace_back(matmul_nodes);
+  result_nodes->emplace_back(split_nodes);
+}
+
+AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
+                            const std::vector<std::vector<size_t>> &split_shapes,
+                            const std::vector<TypeId> &split_types, const std::vector<int> &size_split,
+                            size_t num_split_x) {
+  std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
+                                              input};
+  auto lstm_split = func_graph->NewCNode(lstm_split_input);
+  AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
+  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
+  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), lstm_split);
+  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToInt(num_split_x)), lstm_split);
+  return lstm_split;
+}
+
+AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                                std::vector<AnfNodePtr> *outputs) {
+  std::vector<std::vector<AnfNodePtr>> result_nodes;
+  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
+
+  auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0);
+  std::vector<size_t> split_c_dims{1, origin_input5_shape[0], origin_input5_shape[1]};
+
+  auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
+  size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
+  std::vector<std::vector<size_t>> split_shapes;
+  std::vector<TypeId> split_types;
+  std::vector<int> size_split;
+  for (size_t i = 0; i < num_split_x; ++i) {
+    split_shapes.emplace_back(split_c_dims);
+    split_types.emplace_back(kNumberTypeFloat32);
+    size_split.emplace_back(1);
+  }
+  // Create lstm_split_c
+  auto lstm_split_c = CreateLSTMSPlitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_c_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);
+
+  // Create lstm_split_dy
+  auto lstm_split_dy =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(9), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_dy_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);
+
+  // Create lstm_split_i
+  auto lstm_split_i =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(12), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_i_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);
+
+  // Create lstm_split_j
+  auto lstm_split_j =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(13), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_j_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);
+
+  // Create lstm_split_f
+  auto lstm_split_f =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(14), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_f_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);
+
+  // Create lstm_split_o
+  auto lstm_split_o =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(15), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_o_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);
+
+  // Create lstm_split_tanh
+  auto lstm_split_tanh =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(16), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_tanh_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);
+
+  // Add edges
+  std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
+  std::vector<AnfNodePtr> pre_split_outputs;
+  auto basic_lstm_cell_c_state_grad_nodes = result_nodes[0];
+  auto matmul_nodes = result_nodes[1];
+  auto split_nodes = result_nodes[2];
+  std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
+  lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
+  std::vector<AnfNodePtr> lstm_gage_concat_input(num_split_x + 1);
+  lstm_gage_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
+
+  for (size_t i = 0; i < num_split_x; ++i) {
+    size_t idx = num_split_x - i - 1;
+    // Create basic_lstm_cell_c_state_grad
+    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
+    if (i == num_split_x - 1) {
+      std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
+                                                dynamic_rnn_grad_cnode->input(6)};
+      auto reshape = func_graph->NewCNode(reshape_inputs);
+      auto reshape_out_shape = {IntToSize(1), AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[0],
+                                AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[1]};
+      AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {reshape_out_shape}, reshape.get());
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(reshape);
+    } else {
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_c_outputs[idx - 1]);
+    }
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_dy_outputs[idx]);
+    if (i == 0) {
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(10));
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(11));
+    } else {
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_split_outputs[1]);
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
+    }
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_i_outputs[idx]);
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_j_outputs[idx]);
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_f_outputs[idx]);
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
+    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
+    MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
+    basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
+    AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
+    // Create outputs for current basic_lstm_cell_c_state_grad node
+    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_outputs;
+    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, 2, &basic_lstm_cell_c_state_grad_outputs);
+    pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;
+
+    // Create MatMul
+    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
+    matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
+    matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(2));
+    auto matmul = func_graph->NewCNode(matmul_inputs);
+    MS_EXCEPTION_IF_NULL(matmul);
+    matmul->set_abstract(matmul_nodes[i]->abstract());
+    AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);
+
+    // Create splitv
+    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
+                                            matmul};
+    auto split_v = func_graph->NewCNode(splitv_input);
+    MS_EXCEPTION_IF_NULL(split_v);
+    split_v->set_abstract(split_nodes[i]->abstract());
+    AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);
+
+    // Create outputs for current split node
+    std::vector<AnfNodePtr> split_outputs;
+    CreateMultipleOutputsOfAnfNode(func_graph, split_v, 2, &split_outputs);
+    pre_split_outputs = split_outputs;
+
+    lstm_x_concat_input[idx + 1] = split_outputs[0];
+    lstm_gage_concat_input[idx + 1] = basic_lstm_cell_c_state_grad_outputs[0];
+  }
+
+  // Create lstm_x_concat
+  auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)},
+                                      lstm_x_concat.get());
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_x_concat);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_x_concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_x_concat);
+
+  // Create lstm_gage_concat
+  auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
+  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
+                                      {{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
+                                      lstm_gage_concat.get());
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_gage_concat);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_gage_concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat);
+
+  outputs->emplace_back(lstm_x_concat);
+  outputs->emplace_back(pre_split_outputs[1]);
+  outputs->emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
+  return lstm_gage_concat;
+}
+
+AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  // Create node
+  auto origin_input6 = dynamic_rnn_grad_cnode->input(7);
+  std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
+                                          origin_input6};
+  auto split_v = func_graph->NewCNode(splitv_input);
+  // Set infer data type and shape
+  auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
+  auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);
+  std::vector<size_t> shape1 = {origin_input6_shape[0] - 1, origin_input6_shape[1], origin_input6_shape[2]};
+  std::vector<size_t> shape2 = {1, origin_input6_shape[1], origin_input6_shape[2]};
+  std::vector<std::vector<size_t>> shapes = {shape1, shape2};
+  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
+  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
+  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int>{SizeToInt(origin_input6_shape[0] - 1), 1}), split_v);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
+  return split_v;
+}
+
+AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                         const AnfNodePtr &splitv) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  MS_EXCEPTION_IF_NULL(splitv);
+  // Create node
+  std::vector<AnfNodePtr> splitv_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, splitv, kSplitVOutputNum, &splitv_outputs);
+  if (splitv_outputs.size() != kSplitVOutputNum) {
node " << splitv->DebugString() << " failed"; + } + auto origin_input4 = dynamic_rnn_grad_cnode->input(5); + auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0); + // Create reshape to change shape + std::vector shape_tmp; + if (origin_input4_shape.size() == 3) { + shape_tmp = origin_input4_shape; + } else { + shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]}; + } + std::vector reshape_input = {NewValueNode(std::make_shared(prim::kPrimReshape->name())), + origin_input4}; + auto reshape = func_graph->NewCNode(reshape_input); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get()); + + std::vector concat_inputs = {NewValueNode(std::make_shared(prim::kPrimConcat->name())), + reshape, splitv_outputs[0]}; + auto concat = func_graph->NewCNode(concat_inputs); + // Set infer data type and shape + auto splitv_output0_shape = AnfAlgo::GetOutputInferShape(splitv, 0); + std::vector shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]}; + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get()); + // Set attr + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat); + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector{2}), concat); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat); + AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat); + return concat; +} + +AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode, + const AnfNodePtr &h_concat) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode); + // Create node + auto origin_input0 = dynamic_rnn_grad_cnode->input(1); + std::vector concat_inputs = {NewValueNode(std::make_shared(prim::kPrimConcat->name())), + origin_input0, h_concat}; + auto concat = func_graph->NewCNode(concat_inputs); + // Set infer data type and shape + auto origin_output0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0); + auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0); + std::vector shape = {origin_output0_shape[0], origin_output0_shape[1], + origin_output0_shape[2] + h_concat_output_shape[2]}; + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get()); + // Set attr + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat); + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector{2}), concat); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat); + AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat); + return concat; +} + +AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode); + // Create node + auto origin_input0 = dynamic_rnn_grad_cnode->input(1); + auto origin_input4 = dynamic_rnn_grad_cnode->input(5); + auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0); + // Create reshape to change shape + std::vector shape_tmp; + if (origin_input4_shape.size() == 3) { + shape_tmp = origin_input4_shape; + } else { + shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]}; + } + std::vector reshape_input = {NewValueNode(std::make_shared(prim::kPrimReshape->name())), + origin_input4}; + auto reshape = func_graph->NewCNode(reshape_input); + AnfAlgo::SetNodeAttr(kAttrVisited, 
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
+
+  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
+                                           origin_input0, reshape};
+  auto concat = func_graph->NewCNode(concat_inputs);
+  // Set infer data type and shape
+  auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
+  std::vector<size_t> shape = {origin_input0_shape[0], origin_input0_shape[1], origin_input0_shape[2] + shape_tmp[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
+  return concat;
+}
+
+AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
+                             const AnfNodePtr &concat) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
+                                           concat, lstm_input_grad};
+  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
+  // Set infer data type and shape
+  auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
+  auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);
+  std::vector<size_t> shape = {concat_shape[0], concat_shape[2], lstm_input_grad_shape[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, batch_matmul.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
+  return batch_matmul;
+}
+
+AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                             const AnfNodePtr &batch_matmul) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
+                                               batch_matmul};
+  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
+  // Set infer data type and shape
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
+                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sum);
+  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
+  return reduce_sum;
+}
+
+AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                             const AnfNodePtr &lstm_input_grad) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
+                                               lstm_input_grad};
+  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
+  // Set infer data type and shape
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 1)},
+                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 1)}, reduce_sum.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sum);
+  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
+  return reduce_sum;
+}
+}  // namespace
+
+const BaseRef DynamicRnnGradFissionV2::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
+}
+
+const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto dynamic_rnn_grad_cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  if (dynamic_rnn_grad_cnode->inputs().size() < kDynamicRNNGradInputNum + 1) {
+    MS_LOG(INFO) << "The node " << dynamic_rnn_grad_cnode->DebugString() << " has less than "
+                 << kDynamicRNNGradInputNum + 1 << " inputs";
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> new_outputs;
+  auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
+
+  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(7), 0)[0];
+  AnfNodePtr concat = nullptr;
+  if (t_size != 1) {
+    auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
+    auto h_concat = CreateHConcat(func_graph, dynamic_rnn_grad_cnode, splitv);
+    concat = CreateConcat(func_graph, dynamic_rnn_grad_cnode, h_concat);
+  } else {
+    concat = CreateConcatNodeT1(func_graph, dynamic_rnn_grad_cnode);
+  }
+
+  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_input_grad, concat);
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  if (t_size != 1) {
+    auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
+    make_tuple_inputs.emplace_back(dw_reduce_sum);
+  } else {
+    make_tuple_inputs.emplace_back(batch_matmul);
+  }
+
+  // create reduce_sum_2
+  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad);
+  make_tuple_inputs.emplace_back(db_reduce_sum);
+  make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
+  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  return make_tuple;
+}
+}  // namespace opt
+}  // namespace mindspore
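A note on the repeated (x + 15) / 16 arithmetic in CreateTLoopNode and the kAttrSizeSplits attributes above: the per-timestep dgate tensor pads each of the four gates to a 16-aligned width, and size_splits counts 16-wide blocks rather than elements. A minimal Python sketch, with illustrative names that are not part of the patch:

    def ceil16(x):
        """Round x up to the next multiple of 16 (the NZ tile width)."""
        return (x + 15) // 16 * 16

    def t_loop_dims(hidden_size, input_size):
        # dgate packs the i/j/f/o gate gradients, each padded to 16 columns.
        dgate_cols = 4 * ceil16(hidden_size)
        # kAttrSizeSplits for the per-step SplitV: dx and dh, in 16-col blocks.
        size_splits = [(input_size + 15) // 16, (hidden_size + 15) // 16]
        return dgate_cols, size_splits

    print(t_loop_dims(100, 96))  # (448, [6, 7])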
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h
similarity index 77%
rename from mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h
rename to mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h
index 5c36c17e44..6bbeede489 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h
@@ -13,21 +13,22 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
 #include "backend/optimizer/common/optimizer.h"
 
 namespace mindspore {
 namespace opt {
-class DynamicRNNGradFission : public PatternProcessPass {
+class DynamicRnnGradFissionV2 : public PatternProcessPass {
  public:
-  explicit DynamicRNNGradFission(bool multigraph = true) : PatternProcessPass("dynamic_rnn_grad_fission", multigraph) {}
-  ~DynamicRNNGradFission() override = default;
+  explicit DynamicRnnGradFissionV2(bool multigraph = true)
+      : PatternProcessPass("dynamic_rnn_grad_fission_v2", multigraph) {}
+  ~DynamicRnnGradFissionV2() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 };
 }  // namespace opt
 }  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 83b6ff2fb6..4bdb1e661d 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -232,6 +232,9 @@ constexpr auto kSparseApplyFtrlName = "SparseApplyFtrl";
 constexpr auto kSparseApplyFtrlV2Name = "SparseApplyFtrlV2";
 constexpr auto kSGDName = "SGD";
 constexpr auto kLARSUpdateName = "LARSUpdate";
+constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
+constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
+constexpr auto kMatMulV2OpName = "MatMulV2";
 
 // Hcom Op Type
 constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 3fa5d0b896..05bc1f7a39 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -282,6 +282,7 @@ from .inv import _inv_tbe
 from .inv_grad import _inv_grad_tbe
 from .invert import _invert_tbe
 from .basic_lstm_cell import _basic_lstm_cell_tbe
+from .basic_lstm_cell_c_state_grad_v2 import _basic_lstm_cell_c_state_grad_tbe_v2
 from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
 from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
 from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py
new file mode 100644
index 0000000000..37dc160b58
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BasicLSTMCellCStateGradV2 op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+basic_lstm_cell_c_state_grad_op_info_v2 = TBERegOp("BasicLSTMCellCStateGradV2") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("basic_lstm_cell_c_state_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("basic_lstm_cell_c_state_grad_v2") \
+    .attr("forget_bias", "optional", "float", "all") \
+    .attr("activation", "optional", "str", "all") \
+    .partial_flag(True) \
+    .input(0, "c", False, "required", "all") \
+    .input(1, "dy", False, "required", "all") \
+    .input(2, "dht", False, "required", "all") \
+    .input(3, "dct", False, "required", "all") \
+    .input(4, "it", False, "required", "all") \
+    .input(5, "jt", False, "required", "all") \
+    .input(6, "ft", False, "required", "all") \
+    .input(7, "ot", False, "required", "all") \
+    .input(8, "tanhct", False, "required", "all") \
+    .output(0, "dgate", False, "required", "all") \
+    .output(1, "dct_1", False, "required", "all") \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
+                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
+                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
+    .get_op_info()
+
+
+@op_info_register(basic_lstm_cell_c_state_grad_op_info_v2)
+def _basic_lstm_cell_c_state_grad_tbe_v2():
+    """BasicLSTMCellCStateGradV2 TBE register"""
+    return
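For readers checking the input/output contract registered above, the following NumPy sketch shows the standard LSTM cell-state backward that a kernel with this signature computes. It is a plain reference under the usual sigmoid/tanh gate convention (with dy and dht both feeding the h gradient, and gates packed in i, j, f, o order); it is not derived from the TBE kernel itself, which may pack and pad dgate differently (see the 16-alignment note earlier).

    import numpy as np

    def basic_lstm_cell_c_state_grad_ref(c, dy, dht, dct, it, jt, ft, ot, tanhct):
        """Reference LSTM cell-state backward (assumed gate convention)."""
        dh = dy + dht                          # total gradient on h_t
        dc = dct + dh * ot * (1 - tanhct**2)   # total gradient on c_t
        dit = dc * jt * it * (1 - it)          # input gate (sigmoid')
        djt = dc * it * (1 - jt**2)            # candidate (tanh')
        dft = dc * c * ft * (1 - ft)           # forget gate (sigmoid')
        dot = dh * tanhct * ot * (1 - ot)      # output gate (sigmoid')
        dgate = np.concatenate([dit, djt, dft, dot], axis=-1)
        dct_1 = dc * ft                        # gradient flowing to c_{t-1}
        return dgate, dct_1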