| @@ -140,7 +140,8 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"i_fmr", "ifmr"}, | |||
| {"matrix_diag", "matrix_diag_d"}, | |||
| {"matrix_diag_part", "matrix_diag_part_d"}, | |||
| {"matrix_set_diag", "matrix_set_diag_d"}}; | |||
| {"matrix_set_diag", "matrix_set_diag_d"}, | |||
| {"l_stm_input_grad", "lstm_input_grad"}}; | |||
| void TbeAdapter::NormalizeFuncName(std::string *func_name) { | |||
| if (func_name == nullptr) { | |||
| @@ -150,7 +150,13 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_ | |||
| MS_EXCEPTION_IF_NULL(input_ptr); | |||
| MS_EXCEPTION_IF_NULL(input_list); | |||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||
| if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) { | |||
| if (op_name == kDynamicRNNOpName && input_ptr->name() == "seq_length") { | |||
| nlohmann::json input_desc_json; | |||
| auto in_name = input_ptr->name(); | |||
| input_desc_json[kJName] = in_name + std::to_string(input_i); | |||
| input_desc_json[kJValid] = false; | |||
| input_list->emplace_back(input_desc_json); | |||
| } else if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) { | |||
| TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_); | |||
| } else { | |||
| auto dtype = GetDeviceInputType(anf_node, real_input_index); | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h" | |||
| #include "backend/optimizer/ascend/ir_fission/bn_split.h" | |||
| #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h" | |||
| #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h" | |||
| @@ -107,6 +108,7 @@ | |||
| #include "backend/optimizer/ascend/ir_fission/concat_fission.h" | |||
| #include "backend/optimizer/ascend/ir_fission/pack_fission.h" | |||
| #include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h" | |||
| #include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | |||
| @@ -278,6 +280,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| } | |||
| ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<DynamicRNNGradFission>()); | |||
| AddAscendIRFusionRulesPass(ir_fusion_pm.get()); | |||
| AddAscendIRFusionPass(ir_fusion_pm.get()); | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef InsertPlaceholderForDynamicRNN::DefinePattern() const { | |||
| std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited); | |||
| std::shared_ptr<Var> Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({V, Xs}); | |||
| } | |||
| const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||
| if (op_name != kDynamicRNNOpName) { | |||
| return nullptr; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| if (input_num == 0) { | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | |||
| for (size_t in_idx = 0; in_idx < input_num; in_idx++) { | |||
| auto input_node = AnfAlgo::GetInputNode(cnode, in_idx); | |||
| if (in_idx == 3) { | |||
| auto value = std::make_shared<None>(); | |||
| auto value_node = NewValueNode(value); | |||
| value_node->set_abstract(std::make_shared<abstract::AbstractNone>()); | |||
| auto new_node = kernel_graph->NewValueNode(value_node); | |||
| kernel_graph->AddValueNodeToGraph(new_node); | |||
| new_inputs.push_back(new_node); | |||
| } | |||
| new_inputs.push_back(input_node); | |||
| } | |||
| CNodePtr new_node = nullptr; | |||
| if (kernel_graph == nullptr) { | |||
| new_node = std::make_shared<CNode>(*cnode); | |||
| } else { | |||
| new_node = kernel_graph->NewCNode(cnode); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| new_node->set_inputs(new_inputs); | |||
| return new_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class InsertPlaceholderForDynamicRNN : public PatternProcessPass { | |||
| public: | |||
| explicit InsertPlaceholderForDynamicRNN(bool multigraph = true) | |||
| : PatternProcessPass("add_placeholder_for_dynamic_rnn", multigraph) {} | |||
| ~InsertPlaceholderForDynamicRNN() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H | |||
| @@ -0,0 +1,250 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| constexpr size_t kDynamicRNNGradInputNum = 16; | |||
| constexpr size_t kLSTMInputGradOutputNum = 4; | |||
| const BaseRef DynamicRNNGradFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimDynamicRNNGrad, Xs}); | |||
| } | |||
| AnfNodePtr CreateSplitVD(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| // SplitV | |||
| std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node}; | |||
| auto split_vd = graph->NewCNode(splitvd_input); | |||
| MS_EXCEPTION_IF_NULL(split_vd); | |||
| auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node, 0)[0] - 1, AnfAlgo::GetOutputInferShape(node, 0)[1], | |||
| AnfAlgo::GetOutputInferShape(node, 0)[2]}; | |||
| auto shape2 = {IntToSize(1), AnfAlgo::GetOutputInferShape(node, 0)[1], AnfAlgo::GetOutputInferShape(node, 0)[2]}; | |||
| std::vector<std::vector<size_t>> shapes = {shape, shape2}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get()); | |||
| AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split_vd); | |||
| AnfAlgo::SetNodeAttr("num_split", MakeValue(2), split_vd); | |||
| int tmp = SizeToInt(AnfAlgo::GetOutputInferShape(node, 0)[0]) - 1; | |||
| AnfAlgo::SetNodeAttr("size_splits", MakeValue(std::vector<int>{tmp, 1}), split_vd); | |||
| AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd); | |||
| return split_vd; | |||
| } | |||
| AnfNodePtr CreateLSTMInputGrad(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &dynamic_rnn_grad_inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> lstm_input_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kLSTMInputGradOpName)), | |||
| dynamic_rnn_grad_inputs[2], | |||
| dynamic_rnn_grad_inputs[6], | |||
| dynamic_rnn_grad_inputs[8], | |||
| dynamic_rnn_grad_inputs[9], | |||
| dynamic_rnn_grad_inputs[10], | |||
| dynamic_rnn_grad_inputs[11], | |||
| dynamic_rnn_grad_inputs[12], | |||
| dynamic_rnn_grad_inputs[13], | |||
| dynamic_rnn_grad_inputs[14], | |||
| dynamic_rnn_grad_inputs[15], | |||
| dynamic_rnn_grad_inputs[16]}; | |||
| std::vector<AnfNodePtr> ori_outputs; | |||
| CreateMultipleOutputsOfAnfNode(graph, node, 5, &ori_outputs); | |||
| auto lstm_op = graph->NewCNode(lstm_input_grad_inputs); | |||
| MS_EXCEPTION_IF_NULL(lstm_op); | |||
| auto ori_type = AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_inputs[8], 0); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[2], 0), AnfAlgo::GetOutputInferDataType(ori_outputs[3], 0), | |||
| AnfAlgo::GetOutputInferDataType(ori_outputs[4], 0), ori_type}; | |||
| std::vector<size_t> ori_shape = {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[0], | |||
| AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[1], | |||
| 4 * AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[2]}; | |||
| auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[2], 0), AnfAlgo::GetOutputInferShape(ori_outputs[3], 0), | |||
| AnfAlgo::GetOutputInferShape(ori_outputs[4], 0), ori_shape}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lstm_op.get()); | |||
| return lstm_op; | |||
| } | |||
| AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node1); | |||
| MS_EXCEPTION_IF_NULL(node2); | |||
| // BatchMatMul | |||
| std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())), | |||
| node2, node1}; | |||
| auto batch_matmul = graph->NewCNode(matmul_inputs); | |||
| MS_EXCEPTION_IF_NULL(batch_matmul); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)}; | |||
| std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0], AnfAlgo::GetOutputInferShape(node2, 0)[2], | |||
| AnfAlgo::GetOutputInferShape(node1, 0)[2]}; | |||
| auto shapes = {shape}; | |||
| AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul); | |||
| AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul); | |||
| AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul); | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, batch_matmul.get()); | |||
| return batch_matmul; | |||
| } | |||
| AnfNodePtr AddHConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node1); | |||
| MS_EXCEPTION_IF_NULL(node2); | |||
| std::vector<AnfNodePtr> ori_outputs; | |||
| CreateMultipleOutputsOfAnfNode(graph, node2, 2, &ori_outputs); | |||
| auto ori_shape = AnfAlgo::GetOutputInferShape(node1, 0); | |||
| std::vector<std::vector<size_t>> shape_tmp; | |||
| if (ori_shape.size() == 3) { | |||
| shape_tmp = {ori_shape}; | |||
| } else { | |||
| shape_tmp = {{IntToSize(1), ori_shape[0], ori_shape[1]}}; | |||
| } | |||
| auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node1, 0)}; | |||
| // reshape | |||
| std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), | |||
| node1}; | |||
| auto reshape = graph->NewCNode(reshape_input); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape); | |||
| AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get()); | |||
| // concatd --> concat | |||
| std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), | |||
| reshape, ori_outputs[0]}; | |||
| auto concat_op = graph->NewCNode(concat_inputs); | |||
| MS_EXCEPTION_IF_NULL(concat_op); | |||
| std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node2, 0)[0] + 1, AnfAlgo::GetOutputInferShape(node2, 0)[1], | |||
| AnfAlgo::GetOutputInferShape(node2, 0)[2]}; | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)}; | |||
| auto shapes = {input}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op); | |||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op); | |||
| AnfAlgo::SetNodeAttr("axis", MakeValue(0), concat_op); | |||
| AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op); | |||
| return concat_op; | |||
| } | |||
| AnfNodePtr AddConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node1); | |||
| MS_EXCEPTION_IF_NULL(node2); | |||
| // concatd --> concat | |||
| std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), node1, | |||
| node2}; | |||
| auto concat_op = graph->NewCNode(concat_inputs); | |||
| MS_EXCEPTION_IF_NULL(concat_op); | |||
| std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[1], | |||
| AnfAlgo::GetOutputInferShape(node1, 0)[2] + AnfAlgo::GetOutputInferShape(node2, 0)[2]}; | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)}; | |||
| auto shapes = {input}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op); | |||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op); | |||
| AnfAlgo::SetNodeAttr("axis", MakeValue(2), concat_op); | |||
| AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op); | |||
| return concat_op; | |||
| } | |||
| AnfNodePtr AddDwReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||
| // node1 : dynamic output | |||
| // node2 : matmul | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node1); | |||
| MS_EXCEPTION_IF_NULL(node2); | |||
| std::vector<AnfNodePtr> ori_outputs; | |||
| CreateMultipleOutputsOfAnfNode(graph, node1, 5, &ori_outputs); | |||
| // ReduceSumd | |||
| std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), | |||
| node2}; | |||
| auto reduce_sumd = graph->NewCNode(reducesum_inputs); | |||
| MS_EXCEPTION_IF_NULL(reduce_sumd); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[0], 0)}; | |||
| auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[0], 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sumd); | |||
| AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd); | |||
| AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd); | |||
| return reduce_sumd; | |||
| } | |||
| AnfNodePtr AddDbReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||
| // node1 lstm output | |||
| // node2 // dynamic output | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node1); | |||
| MS_EXCEPTION_IF_NULL(node2); | |||
| std::vector<AnfNodePtr> ori_outputs; | |||
| CreateMultipleOutputsOfAnfNode(graph, node2, 5, &ori_outputs); | |||
| // ReduceSumd --> ReduceSum | |||
| std::vector<AnfNodePtr> reducerum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), | |||
| node1}; | |||
| auto reduce_sumd = graph->NewCNode(reducerum_inputs); | |||
| MS_EXCEPTION_IF_NULL(reduce_sumd); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[1], 0)}; | |||
| auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[1], 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sumd); | |||
| AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd); | |||
| AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd); | |||
| return reduce_sumd; | |||
| } | |||
| const AnfNodePtr DynamicRNNGradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->size() < kDynamicRNNGradInputNum + 1) { | |||
| MS_LOG(INFO) << "The input num of DynamicRNNGrad less than" << kDynamicRNNGradInputNum | |||
| << ". The node should not be changed"; | |||
| return nullptr; | |||
| } | |||
| // input_list of dynamic_rnn_grad | |||
| const auto &ori_inputs = cnode->inputs(); | |||
| // create split_vd | |||
| auto split_vd = CreateSplitVD(func_graph, ori_inputs[7]); | |||
| // create concat_1 | |||
| auto h_concat = AddHConcatD(func_graph, ori_inputs[5], split_vd); | |||
| // create concat_2 | |||
| auto concat = AddConcatD(func_graph, ori_inputs[1], h_concat); | |||
| // create lsym_input_grad | |||
| auto lstm_input_grad = CreateLSTMInputGrad(func_graph, cnode); | |||
| std::vector<AnfNodePtr> lstm_outputs; | |||
| CreateMultipleOutputsOfAnfNode(func_graph, lstm_input_grad, kLSTMInputGradOutputNum, &lstm_outputs); | |||
| // create matmul | |||
| auto batch_matmul = CreateBatchMatMul(func_graph, lstm_outputs[3], concat); | |||
| // create reduce_sum_1 | |||
| auto dw_reduce_sum = AddDwReduceSum(func_graph, node, batch_matmul); | |||
| // create reduce_sum_2 | |||
| auto db_reduce_sum = AddDbReduceSum(func_graph, lstm_outputs[3], node); | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), | |||
| dw_reduce_sum, | |||
| db_reduce_sum, | |||
| lstm_outputs[0], | |||
| lstm_outputs[1], | |||
| lstm_outputs[2]}; | |||
| auto make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| return make_tuple; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class DynamicRNNGradFission : public PatternProcessPass { | |||
| public: | |||
| explicit DynamicRNNGradFission(bool multigraph = true) : PatternProcessPass("dynamic_rnn_grad_fission", multigraph) {} | |||
| ~DynamicRNNGradFission() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_ | |||
| @@ -21,11 +21,10 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, | |||
| {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, | |||
| {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, | |||
| {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM}, | |||
| {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | |||
| const std::set<std::pair<string, string>> invalid_formats_pair = { | |||
| {kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, | |||
| {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM}, | |||
| {kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | |||
| bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -83,6 +82,9 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||
| new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, | |||
| false, prim::kPrimTranspose->name()); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node); | |||
| if (output_format == kOpFormat_FRACTAL_ZN_LSTM) { | |||
| AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node); | |||
| } | |||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node); | |||
| // trans hwcn to output_format | |||
| @@ -404,7 +404,11 @@ bool IsNopNode(const AnfNodePtr &node) { | |||
| } | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end()) { | |||
| bool is_nop_node = false; | |||
| if (AnfAlgo::HasNodeAttr("nop_op", cnode)) { | |||
| is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op"); | |||
| } | |||
| if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) { | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -52,8 +52,12 @@ const int kUnSupportMixedDataTypeIndex = -1; | |||
| bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| // Check input data type | |||
| auto name = AnfAlgo::GetCNodeName(cnode); | |||
| for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | |||
| TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||
| if (name == kDynamicRNNOpName && input_origin_type == kMetaTypeNone) { | |||
| continue; | |||
| } | |||
| if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { | |||
| return false; | |||
| } | |||
| @@ -478,6 +482,9 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); | |||
| continue; | |||
| } | |||
| if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) { | |||
| continue; | |||
| } | |||
| // we set special device info of a input tensor. | |||
| bool is_ref = false; | |||
| auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); | |||
| @@ -127,8 +127,12 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope()); | |||
| auto op_name = AnfAlgo::GetCNodeName(anf_node_ptr); | |||
| if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) { | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) { | |||
| if (op_name == kDynamicRNNOpName && i == 3) { | |||
| continue; | |||
| } | |||
| auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i); | |||
| auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index); | |||
| AddressPtr input = std::make_shared<Address>(); | |||
| @@ -219,6 +219,8 @@ constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum"; | |||
| constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad"; | |||
| constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad"; | |||
| constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell"; | |||
| constexpr auto kDynamicRNNOpName = "DynamicRNN"; | |||
| constexpr auto kLSTMInputGradOpName = "LSTMInputGrad"; | |||
| // attr key name | |||
| constexpr auto kAttrInputNames = "input_names"; | |||
| @@ -105,6 +105,8 @@ inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("Ar | |||
| inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique"); | |||
| inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad"); | |||
| inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches"); | |||
| inline const PrimitivePtr kPrimDynamicRNN = std::make_shared<Primitive>("DynamicRNN"); | |||
| inline const PrimitivePtr kPrimDynamicRNNGrad = std::make_shared<Primitive>("DynamicRNNGrad"); | |||
| // NN | |||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| @@ -214,6 +216,7 @@ inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round"); | |||
| inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp"); | |||
| inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | |||
| inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | |||
| inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | |||
| // Statements | |||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | |||