@@ -19,7 +19,7 @@
 #include <memory>
 #include <string>
 #include "backend/optimizer/common/optimizer.h"
-#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
+#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
 #include "backend/optimizer/ascend/ir_fission/bn_split.h"
 #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
 #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
@@ -61,6 +61,7 @@
 #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
 #include "backend/optimizer/ascend/format_type/insert_trans_op.h"
+#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
 #include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
 #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
 #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
@@ -215,6 +216,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
   data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
+  data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
   data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
   data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
   data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
@@ -276,7 +278,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph)
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
-  ir_fusion_pm->AddPass(std::make_shared<DynamicRNNGradFission>());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
   AddAscendIRFusionRulesPass(ir_fusion_pm.get());
   AddAscendIRFusionPass(ir_fusion_pm.get());
@@ -0,0 +1,80 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
+#include <memory>
+#include "backend/optimizer/ascend/ascend_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "base/core_ops.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef DynamicRNNGradReformat::DefinePattern() const {
+  VarPtr Xs = std::make_shared<Var>();
+  VarPtr Xs2 = std::make_shared<Var>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Xs2);
+  const auto split = std::make_shared<Primitive>(prim::kPrimSplitV->name());
+  return VectorRef({split, VectorRef({std::make_shared<Primitive>(prim::kPrimMatMul->name()), Xs, Xs2})});
+}
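+
+// The pattern matches any SplitV fed by a MatMul; Process() below additionally
+// requires the MatMul input to come from a BasicLSTMCellCStateGradV2 node,
+// i.e. the MatMul/SplitV chain built by DynamicRnnGradFissionV2, and bails
+// out otherwise.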
+const AnfNodePtr DynamicRNNGradReformat::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                 const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto split_v = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(split_v);
+  auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), 3);
+  MS_EXCEPTION_IF_NULL(matmul);
+  auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(matmul, 0);
+  auto input_node = input_node_with_idx.first;
+  MS_EXCEPTION_IF_NULL(input_node);
+  if (!(input_node->isa<CNode>() &&
+        AnfAlgo::GetCNodeName(input_node->cast<CNodePtr>()) == kBasicLSTMCellCStateGradV2OpName)) {
+    return nullptr;
+  }
+  // reformat matmul
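+  // Pin the MatMul (and the SplitV below) to FRACTAL_NZ in/out, fp16 inputs
+  // and a float output, so the chain coming from BasicLSTMCellCStateGradV2,
+  // which is registered in FRACTAL_NZ (see the TBE registration), needs no
+  // TransData inserted in between by the later format passes.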
+  auto matmul_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(matmul);
+  MS_EXCEPTION_IF_NULL(matmul_kernel_build_info);
+  auto matmul_new_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  matmul_new_builder->SetInputsFormat({kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ});
+  matmul_new_builder->SetOutputsFormat({kOpFormat_FRAC_NZ});
+  matmul_new_builder->SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
+  matmul_new_builder->SetOutputsDeviceType({kNumberTypeFloat});
+  matmul_new_builder->SetKernelType(matmul_kernel_build_info->kernel_type());
+  matmul_new_builder->SetFusionType(matmul_kernel_build_info->fusion_type());
+  matmul_new_builder->SetProcessor(matmul_kernel_build_info->processor());
+  AnfAlgo::SetSelectKernelBuildInfo(matmul_new_builder->Build(), matmul.get());
+  AnfAlgo::SetNodeAttr("insert_backend", MakeValue(true), matmul);
+  // reformat split_v
+  auto split_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(split_v);
+  MS_EXCEPTION_IF_NULL(split_kernel_build_info);
+  auto split_new_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  split_new_builder->SetInputsFormat({kOpFormat_FRAC_NZ});
+  split_new_builder->SetOutputsFormat({kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ});
+  split_new_builder->SetInputsDeviceType(split_kernel_build_info->GetAllInputDeviceTypes());
+  split_new_builder->SetOutputsDeviceType(split_kernel_build_info->GetAllOutputDeviceTypes());
+  split_new_builder->SetKernelType(split_kernel_build_info->kernel_type());
+  split_new_builder->SetFusionType(split_kernel_build_info->fusion_type());
+  split_new_builder->SetProcessor(split_kernel_build_info->processor());
+  AnfAlgo::SetSelectKernelBuildInfo(split_new_builder->Build(), split_v.get());
+  AnfAlgo::SetNodeAttr("insert_backend", MakeValue(true), split_v);
+  return split_v;
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
+#include <vector>
+#include <string>
+#include <utility>
+#include <memory>
+#include "ir/anf.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class DynamicRNNGradReformat : public PatternProcessPass {
+ public:
+  explicit DynamicRNNGradReformat(bool multigraph = true)
+      : PatternProcessPass("dynamic_rnn_grad_reformat", multigraph) {}
+  ~DynamicRNNGradReformat() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
@@ -1,250 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
-#include <vector>
-#include <memory>
-#include <algorithm>
-#include "backend/session/anf_runtime_algorithm.h"
-#include "backend/optimizer/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-constexpr size_t kDynamicRNNGradInputNum = 16;
-constexpr size_t kLSTMInputGradOutputNum = 4;
-
-const BaseRef DynamicRNNGradFission::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
-}
-
-AnfNodePtr CreateSplitVD(const FuncGraphPtr &graph, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  // SplitV
-  std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
-  auto split_vd = graph->NewCNode(splitvd_input);
-  MS_EXCEPTION_IF_NULL(split_vd);
-  auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
-  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node, 0)[0] - 1, AnfAlgo::GetOutputInferShape(node, 0)[1],
-                               AnfAlgo::GetOutputInferShape(node, 0)[2]};
-  auto shape2 = {IntToSize(1), AnfAlgo::GetOutputInferShape(node, 0)[1], AnfAlgo::GetOutputInferShape(node, 0)[2]};
-  std::vector<std::vector<size_t>> shapes = {shape, shape2};
-  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
-  AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split_vd);
-  AnfAlgo::SetNodeAttr("num_split", MakeValue(2), split_vd);
-  int tmp = SizeToInt(AnfAlgo::GetOutputInferShape(node, 0)[0]) - 1;
-  AnfAlgo::SetNodeAttr("size_splits", MakeValue(std::vector<int>{tmp, 1}), split_vd);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
-  return split_vd;
-}
-
-AnfNodePtr CreateLSTMInputGrad(const FuncGraphPtr &graph, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  const auto &dynamic_rnn_grad_inputs = cnode->inputs();
-  std::vector<AnfNodePtr> lstm_input_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kLSTMInputGradOpName)),
-                                                    dynamic_rnn_grad_inputs[2],
-                                                    dynamic_rnn_grad_inputs[6],
-                                                    dynamic_rnn_grad_inputs[8],
-                                                    dynamic_rnn_grad_inputs[9],
-                                                    dynamic_rnn_grad_inputs[10],
-                                                    dynamic_rnn_grad_inputs[11],
-                                                    dynamic_rnn_grad_inputs[12],
-                                                    dynamic_rnn_grad_inputs[13],
-                                                    dynamic_rnn_grad_inputs[14],
-                                                    dynamic_rnn_grad_inputs[15],
-                                                    dynamic_rnn_grad_inputs[16]};
-  std::vector<AnfNodePtr> ori_outputs;
-  CreateMultipleOutputsOfAnfNode(graph, node, 5, &ori_outputs);
-  auto lstm_op = graph->NewCNode(lstm_input_grad_inputs);
-  MS_EXCEPTION_IF_NULL(lstm_op);
-  auto ori_type = AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_inputs[8], 0);
-  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[2], 0), AnfAlgo::GetOutputInferDataType(ori_outputs[3], 0),
-                AnfAlgo::GetOutputInferDataType(ori_outputs[4], 0), ori_type};
-  std::vector<size_t> ori_shape = {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[0],
-                                   AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[1],
-                                   4 * AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[2]};
-  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[2], 0), AnfAlgo::GetOutputInferShape(ori_outputs[3], 0),
-                 AnfAlgo::GetOutputInferShape(ori_outputs[4], 0), ori_shape};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lstm_op.get());
-  return lstm_op;
-}
-
-AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  // BatchMatMul
-  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
-                                           node2, node1};
-  auto batch_matmul = graph->NewCNode(matmul_inputs);
-  MS_EXCEPTION_IF_NULL(batch_matmul);
-  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
-  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0], AnfAlgo::GetOutputInferShape(node2, 0)[2],
-                               AnfAlgo::GetOutputInferShape(node1, 0)[2]};
-  auto shapes = {shape};
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
-  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
-  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, batch_matmul.get());
-  return batch_matmul;
-}
-
-AnfNodePtr AddHConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  std::vector<AnfNodePtr> ori_outputs;
-  CreateMultipleOutputsOfAnfNode(graph, node2, 2, &ori_outputs);
-  auto ori_shape = AnfAlgo::GetOutputInferShape(node1, 0);
-  std::vector<std::vector<size_t>> shape_tmp;
-  if (ori_shape.size() == 3) {
-    shape_tmp = {ori_shape};
-  } else {
-    shape_tmp = {{IntToSize(1), ori_shape[0], ori_shape[1]}};
-  }
-  auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node1, 0)};
-  // reshape
-  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
-                                           node1};
-  auto reshape = graph->NewCNode(reshape_input);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
-  AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());
-  // concatd --> concat
-  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
-                                           reshape, ori_outputs[0]};
-  auto concat_op = graph->NewCNode(concat_inputs);
-  MS_EXCEPTION_IF_NULL(concat_op);
-  std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node2, 0)[0] + 1, AnfAlgo::GetOutputInferShape(node2, 0)[1],
-                               AnfAlgo::GetOutputInferShape(node2, 0)[2]};
-  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
-  auto shapes = {input};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
-  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
-  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
-  AnfAlgo::SetNodeAttr("axis", MakeValue(0), concat_op);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
-  return concat_op;
-}
-
-AnfNodePtr AddConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  // concatd --> concat
-  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), node1,
-                                           node2};
-  auto concat_op = graph->NewCNode(concat_inputs);
-  MS_EXCEPTION_IF_NULL(concat_op);
-  std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[1],
-                               AnfAlgo::GetOutputInferShape(node1, 0)[2] + AnfAlgo::GetOutputInferShape(node2, 0)[2]};
-  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
-  auto shapes = {input};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
-  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
-  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
-  AnfAlgo::SetNodeAttr("axis", MakeValue(2), concat_op);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
-  return concat_op;
-}
-
-AnfNodePtr AddDwReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  // node1 : dynamic output
-  // node2 : matmul
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  std::vector<AnfNodePtr> ori_outputs;
-  CreateMultipleOutputsOfAnfNode(graph, node1, 5, &ori_outputs);
-  // ReduceSumd
-  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
-                                              node2};
-  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
-  MS_EXCEPTION_IF_NULL(reduce_sumd);
-  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[0], 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[0], 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sumd);
-  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
-  return reduce_sumd;
-}
-
-AnfNodePtr AddDbReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
-  // node1 lstm output
-  // node2 // dynamic output
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node1);
-  MS_EXCEPTION_IF_NULL(node2);
-  std::vector<AnfNodePtr> ori_outputs;
-  CreateMultipleOutputsOfAnfNode(graph, node2, 5, &ori_outputs);
-  // ReduceSumd --> ReduceSum
-  std::vector<AnfNodePtr> reducerum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
-                                              node1};
-  auto reduce_sumd = graph->NewCNode(reducerum_inputs);
-  MS_EXCEPTION_IF_NULL(reduce_sumd);
-  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[1], 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[1], 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sumd);
-  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
-  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
-  return reduce_sumd;
-}
-
-const AnfNodePtr DynamicRNNGradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                                const EquivPtr &) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->size() < kDynamicRNNGradInputNum + 1) {
-    MS_LOG(INFO) << "The input num of DynamicRNNGrad less than" << kDynamicRNNGradInputNum
-                 << ". The node should not be changed";
-    return nullptr;
-  }
-  // input_list of dynamic_rnn_grad
-  const auto &ori_inputs = cnode->inputs();
-  // create split_vd
-  auto split_vd = CreateSplitVD(func_graph, ori_inputs[7]);
-  // create concat_1
-  auto h_concat = AddHConcatD(func_graph, ori_inputs[5], split_vd);
-  // create concat_2
-  auto concat = AddConcatD(func_graph, ori_inputs[1], h_concat);
-  // create lsym_input_grad
-  auto lstm_input_grad = CreateLSTMInputGrad(func_graph, cnode);
-  std::vector<AnfNodePtr> lstm_outputs;
-  CreateMultipleOutputsOfAnfNode(func_graph, lstm_input_grad, kLSTMInputGradOutputNum, &lstm_outputs);
-  // create matmul
-  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_outputs[3], concat);
-  // create reduce_sum_1
-  auto dw_reduce_sum = AddDwReduceSum(func_graph, node, batch_matmul);
-  // create reduce_sum_2
-  auto db_reduce_sum = AddDbReduceSum(func_graph, lstm_outputs[3], node);
-  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple),
-                                               dw_reduce_sum,
-                                               db_reduce_sum,
-                                               lstm_outputs[0],
-                                               lstm_outputs[1],
-                                               lstm_outputs[2]};
-  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
-  MS_EXCEPTION_IF_NULL(make_tuple);
-  return make_tuple;
-}
-}  // namespace opt
-}  // namespace mindspore
@@ -0,0 +1,483 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
+#include <vector>
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kDynamicRNNGradInputNum = 16;
+constexpr size_t kSplitVOutputNum = 2;
+constexpr size_t kLSTMInputGradOutputNum = 4;
+
+void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                     std::vector<std::vector<AnfNodePtr>> *result_nodes) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  MS_EXCEPTION_IF_NULL(result_nodes);
+  std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_nodes;
+  std::vector<AnfNodePtr> matmul_nodes;
+  std::vector<AnfNodePtr> split_nodes;
+  // Get the size of t
+  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(10), 0);
+  size_t t_size = origin_input9_shape[0];
+  auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(12), 0);
+  for (size_t i = 0; i < t_size; ++i) {
+    // Create basic_lstm_cell_c_state_grad
+    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
+    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
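+    // dgate shape: the hidden dimension is rounded up to the next multiple of
+    // 16 (the FRACTAL_NZ block size) and multiplied by 4 for the i/j/f/o gates.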
+    std::vector<size_t> output0_dims{origin_input9_shape[0], 4 * (((origin_input9_shape[1] + 15) / 16) * 16)};
+    std::vector<size_t> output1_dims{input_i_shape[1], input_i_shape[2]};
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {output0_dims, output1_dims},
+                                        basic_lstm_cell_c_state_grad.get());
+    AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
+    AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);
+    // Create matmul
+    auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(2), 0);
+    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
+    auto matmul = func_graph->NewCNode(matmul_inputs);
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{output0_dims[0], origin_input1_shape[0]}},
+                                        matmul.get());
+    AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
+    AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);
+    // Create split
+    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
+    auto split_v = func_graph->NewCNode(splitv_input);
+    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
+    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
+    std::vector<size_t> split_v_output0_shape{origin_output2_shape[1], origin_output2_shape[2]};
+    std::vector<size_t> split_v_output1_shape{origin_output3_shape[0], origin_output3_shape[1]};
+    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
+                                        {split_v_output0_shape, split_v_output1_shape}, split_v.get());
+    AnfAlgo::SetNodeAttr(kAttrSizeSplits,
+                         MakeValue(std::vector<int>{SizeToInt((origin_output2_shape[2] + 15) / 16),
+                                                    SizeToInt((origin_output3_shape[1] + 15) / 16)}),
+                         split_v);
+    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
+    AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
+    basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
+    matmul_nodes.emplace_back(matmul);
+    split_nodes.emplace_back(split_v);
+  }
+  result_nodes->emplace_back(basic_lstm_cell_c_state_grad_nodes);
+  result_nodes->emplace_back(matmul_nodes);
+  result_nodes->emplace_back(split_nodes);
+}
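+
+// The per-step nodes built by CreateTLoopNode are templates: AddLSTMInputGradNode
+// below copies their abstracts and attrs onto the real per-step
+// BasicLSTMCellCStateGradV2 / MatMul / SplitV nodes once their inputs are wired.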
+AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
+                            const std::vector<std::vector<size_t>> &split_shapes,
+                            const std::vector<TypeId> &split_types, const std::vector<int> &size_split,
+                            size_t num_split_x) {
+  std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
+                                              input};
+  auto lstm_split = func_graph->NewCNode(lstm_split_input);
+  AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
+  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
+  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), lstm_split);
+  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToInt(num_split_x)), lstm_split);
+  return lstm_split;
+}
+
+AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                                std::vector<AnfNodePtr> *outputs) {
+  std::vector<std::vector<AnfNodePtr>> result_nodes;
+  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
+  auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0);
+  std::vector<size_t> split_c_dims{1, origin_input5_shape[0], origin_input5_shape[1]};
+  auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
+  size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
+  std::vector<std::vector<size_t>> split_shapes;
+  std::vector<TypeId> split_types;
+  std::vector<int> size_split;
+  for (size_t i = 0; i < num_split_x; ++i) {
+    split_shapes.emplace_back(split_c_dims);
+    split_types.emplace_back(kNumberTypeFloat32);
+    size_split.emplace_back(1);
+  }
+  // Create lstm_split_c
+  auto lstm_split_c = CreateLSTMSPlitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_c_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);
+  // Create lstm_split_dy
+  auto lstm_split_dy =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(9), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_dy_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);
+  // Create lstm_split_i
+  auto lstm_split_i =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(12), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_i_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);
+  // Create lstm_split_j
+  auto lstm_split_j =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(13), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_j_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);
+  // Create lstm_split_f
+  auto lstm_split_f =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(14), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_f_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);
+  // Create lstm_split_o
+  auto lstm_split_o =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(15), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_o_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);
+  // Create lstm_split_tanh
+  auto lstm_split_tanh =
+    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(16), split_shapes, split_types, size_split, num_split_x);
+  std::vector<AnfNodePtr> lstm_split_tanh_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);
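+  // At this point c, dy, i, j, f, o and tanhc have each been split along the
+  // time axis into num_split_x single-step slices.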
+  // Add edges
+  std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
+  std::vector<AnfNodePtr> pre_split_outputs;
+  auto basic_lstm_cell_c_state_grad_nodes = result_nodes[0];
+  auto matmul_nodes = result_nodes[1];
+  auto split_nodes = result_nodes[2];
+  std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
+  lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
+  std::vector<AnfNodePtr> lstm_gage_concat_input(num_split_x + 1);
+  lstm_gage_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
+  for (size_t i = 0; i < num_split_x; ++i) {
+    size_t idx = num_split_x - i - 1;
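+    // idx runs from t-1 down to 0: the cell-state gradient flows backwards in
+    // time, each step feeding its dct/dgate into the next iteration.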
+    // Create basic_lstm_cell_c_state_grad
+    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
+    if (i == num_split_x - 1) {
+      std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
+                                                dynamic_rnn_grad_cnode->input(6)};
+      auto reshape = func_graph->NewCNode(reshape_inputs);
+      auto reshape_out_shape = {IntToSize(1), AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[0],
+                                AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[1]};
+      AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {reshape_out_shape}, reshape.get());
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(reshape);
+    } else {
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_c_outputs[idx - 1]);
+    }
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_dy_outputs[idx]);
+    if (i == 0) {
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(10));
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(11));
+    } else {
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_split_outputs[1]);
+      basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
+    }
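+    // The first iteration (last time step) consumes DynamicRNNGrad's dh/dc
+    // inputs directly; later iterations chain the previous step's split output
+    // and cell-state gradient instead.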
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_i_outputs[idx]);
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_j_outputs[idx]);
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_f_outputs[idx]);
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
+    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
+    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
+    MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
+    basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
+    AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
+    // Create outputs for current basic_lstm_cell_c_state_grad node
+    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_outputs;
+    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, 2, &basic_lstm_cell_c_state_grad_outputs);
+    pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;
+    // Create MatMul
+    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
+    matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
+    matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(2));
+    auto matmul = func_graph->NewCNode(matmul_inputs);
+    MS_EXCEPTION_IF_NULL(matmul);
+    matmul->set_abstract(matmul_nodes[i]->abstract());
+    AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);
+    // Create splitv
+    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
+                                            matmul};
+    auto split_v = func_graph->NewCNode(splitv_input);
+    MS_EXCEPTION_IF_NULL(split_v);
+    split_v->set_abstract(split_nodes[i]->abstract());
+    AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);
+    // Create outputs for current split node
+    std::vector<AnfNodePtr> split_outputs;
+    CreateMultipleOutputsOfAnfNode(func_graph, split_v, 2, &split_outputs);
+    pre_split_outputs = split_outputs;
+    lstm_x_concat_input[idx + 1] = split_outputs[0];
+    lstm_gage_concat_input[idx + 1] = basic_lstm_cell_c_state_grad_outputs[0];
+  }
+  // Create lstm_x_concat
+  auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)},
+                                      lstm_x_concat.get());
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_x_concat);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_x_concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_x_concat);
+  // Create lstm_gage_concat
+  auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
+  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
+                                      {{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
+                                      lstm_gage_concat.get());
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_gage_concat);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_gage_concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat);
+  outputs->emplace_back(lstm_x_concat);
+  outputs->emplace_back(pre_split_outputs[1]);
+  outputs->emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
+  return lstm_gage_concat;
+}
+
+AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  // Create node
+  auto origin_input6 = dynamic_rnn_grad_cnode->input(7);
+  std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
+                                          origin_input6};
+  auto split_v = func_graph->NewCNode(splitv_input);
+  // Set infer data type and shape
+  auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
+  auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);
+  std::vector<size_t> shape1 = {origin_input6_shape[0] - 1, origin_input6_shape[1], origin_input6_shape[2]};
+  std::vector<size_t> shape2 = {1, origin_input6_shape[1], origin_input6_shape[2]};
+  std::vector<std::vector<size_t>> shapes = {shape1, shape2};
+  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
+  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
+  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int>{SizeToInt(origin_input6_shape[0] - 1), 1}), split_v);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
+  return split_v;
+}
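+
+// CreateHConcat rebuilds the h sequence that is time-aligned with x: init_h is
+// reshaped to rank 3 and concatenated in front of the first t-1 steps of h
+// taken from the SplitV above.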
+AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                         const AnfNodePtr &splitv) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  MS_EXCEPTION_IF_NULL(splitv);
+  // Create node
+  std::vector<AnfNodePtr> splitv_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, splitv, kSplitVOutputNum, &splitv_outputs);
+  if (splitv_outputs.size() != kSplitVOutputNum) {
+    MS_LOG(EXCEPTION) << "Create outputs of node " << splitv->DebugString() << " failed";
+  }
+  auto origin_input4 = dynamic_rnn_grad_cnode->input(5);
+  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
+  // Create reshape to change shape
+  std::vector<size_t> shape_tmp;
+  if (origin_input4_shape.size() == 3) {
+    shape_tmp = origin_input4_shape;
+  } else {
+    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
+  }
+  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
+                                           origin_input4};
+  auto reshape = func_graph->NewCNode(reshape_input);
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
+  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
+                                           reshape, splitv_outputs[0]};
+  auto concat = func_graph->NewCNode(concat_inputs);
+  // Set infer data type and shape
+  auto splitv_output0_shape = AnfAlgo::GetOutputInferShape(splitv, 0);
+  std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
+  return concat;
+}
+
+AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                        const AnfNodePtr &h_concat) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  // Create node
+  auto origin_input0 = dynamic_rnn_grad_cnode->input(1);
+  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
+                                           origin_input0, h_concat};
+  auto concat = func_graph->NewCNode(concat_inputs);
+  // Set infer data type and shape
+  auto origin_output0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
+  auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0);
+  std::vector<size_t> shape = {origin_output0_shape[0], origin_output0_shape[1],
+                               origin_output0_shape[2] + h_concat_output_shape[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
+  return concat;
+}
+
+AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  // Create node
+  auto origin_input0 = dynamic_rnn_grad_cnode->input(1);
+  auto origin_input4 = dynamic_rnn_grad_cnode->input(5);
+  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
+  // Create reshape to change shape
+  std::vector<size_t> shape_tmp;
+  if (origin_input4_shape.size() == 3) {
+    shape_tmp = origin_input4_shape;
+  } else {
+    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
+  }
+  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
+                                           origin_input4};
+  auto reshape = func_graph->NewCNode(reshape_input);
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
+  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
+                                           origin_input0, reshape};
+  auto concat = func_graph->NewCNode(concat_inputs);
+  // Set infer data type and shape
+  auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
+  std::vector<size_t> shape = {origin_input0_shape[0], origin_input0_shape[1], origin_input0_shape[2] + shape_tmp[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
+  return concat;
+}
+
+AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
+                             const AnfNodePtr &concat) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
+                                           concat, lstm_input_grad};
+  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
+  // Set infer data type and shape
+  auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
+  auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);
+  std::vector<size_t> shape = {concat_shape[0], concat_shape[2], lstm_input_grad_shape[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, batch_matmul.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
+  return batch_matmul;
+}
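+
+// With transpose_x1 set, the BatchMatMul computes concat^T * dgate for each of
+// the t batched steps, i.e. a [t, concat_shape[2], 4 * hidden] stack of
+// per-step weight gradients; CreateDwReduceSum below sums it over t to get dw.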
+AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                             const AnfNodePtr &batch_matmul) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
+                                               batch_matmul};
+  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
+  // Set infer data type and shape
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
+                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sum);
+  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
+  return reduce_sum;
+}
+
+AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                             const AnfNodePtr &lstm_input_grad) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
+                                               lstm_input_grad};
+  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
+  // Set infer data type and shape
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 1)},
+                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 1)}, reduce_sum.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sum);
+  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
+  return reduce_sum;
+}
+}  // namespace
+
+const BaseRef DynamicRnnGradFissionV2::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
+}
+
+const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto dynamic_rnn_grad_cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
+  if (dynamic_rnn_grad_cnode->inputs().size() < kDynamicRNNGradInputNum + 1) {
+    MS_LOG(INFO) << "The node " << dynamic_rnn_grad_cnode->DebugString() << " has less than "
+                 << kDynamicRNNGradInputNum + 1 << " inputs";
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> new_outputs;
+  auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
+  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(7), 0)[0];
+  AnfNodePtr concat = nullptr;
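+  // When t == 1 there is nothing to split along time: init_h already is the
+  // whole h sequence and the single BatchMatMul result is dw itself, so the
+  // SplitV/HConcat pair and the dw ReduceSum are skipped below.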
+  if (t_size != 1) {
+    auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
+    auto h_concat = CreateHConcat(func_graph, dynamic_rnn_grad_cnode, splitv);
+    concat = CreateConcat(func_graph, dynamic_rnn_grad_cnode, h_concat);
+  } else {
+    concat = CreateConcatNodeT1(func_graph, dynamic_rnn_grad_cnode);
+  }
+  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_input_grad, concat);
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  if (t_size != 1) {
+    auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
+    make_tuple_inputs.emplace_back(dw_reduce_sum);
+  } else {
+    make_tuple_inputs.emplace_back(batch_matmul);
+  }
+  // create reduce_sum_2
+  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad);
+  make_tuple_inputs.emplace_back(db_reduce_sum);
+  make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
+  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  return make_tuple;
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -13,21 +13,22 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
 #include "backend/optimizer/common/optimizer.h"
 namespace mindspore {
 namespace opt {
-class DynamicRNNGradFission : public PatternProcessPass {
+class DynamicRnnGradFissionV2 : public PatternProcessPass {
  public:
-  explicit DynamicRNNGradFission(bool multigraph = true) : PatternProcessPass("dynamic_rnn_grad_fission", multigraph) {}
-  ~DynamicRNNGradFission() override = default;
+  explicit DynamicRnnGradFissionV2(bool multigraph = true)
+      : PatternProcessPass("dynamic_rnn_grad_fission_v2", multigraph) {}
+  ~DynamicRnnGradFissionV2() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 };
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
@@ -232,6 +232,9 @@ constexpr auto kSparseApplyFtrlName = "SparseApplyFtrl";
 constexpr auto kSparseApplyFtrlV2Name = "SparseApplyFtrlV2";
 constexpr auto kSGDName = "SGD";
 constexpr auto kLARSUpdateName = "LARSUpdate";
+constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
+constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
+constexpr auto kMatMulV2OpName = "MatMulV2";
 
 // Hcom Op Type
 constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
@@ -282,6 +282,7 @@ from .inv import _inv_tbe
 from .inv_grad import _inv_grad_tbe
 from .invert import _invert_tbe
 from .basic_lstm_cell import _basic_lstm_cell_tbe
+from .basic_lstm_cell_c_state_grad_v2 import _basic_lstm_cell_c_state_grad_tbe_v2
 from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
 from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
 from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
| @@ -0,0 +1,51 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BasicLSTMCellCStateGradV2 op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| basic_lstm_cell_c_state_grad_op_info_v2 = TBERegOp("BasicLSTMCellCStateGradV2") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("basic_lstm_cell_c_state_grad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("basic_lstm_cell_c_state_grad_v2") \ | |||||
| .attr("forget_bias", "optional", "float", "all") \ | |||||
| .attr("activation", "optional", "str", "all") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "c", False, "required", "all") \ | |||||
| .input(1, "dy", False, "required", "all") \ | |||||
| .input(2, "dht", False, "required", "all") \ | |||||
| .input(3, "dct", False, "required", "all") \ | |||||
| .input(4, "it", False, "required", "all") \ | |||||
| .input(5, "jt", False, "required", "all") \ | |||||
| .input(6, "ft", False, "required", "all") \ | |||||
| .input(7, "ot", False, "required", "all") \ | |||||
| .input(8, "tanhct", False, "required", "all") \ | |||||
| .output(0, "dgate", False, "required", "all") \ | |||||
| .output(1, "dct_1", False, "required", "all") \ | |||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| .get_op_info() | |||||
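+# Each dtype_format row has eleven entries: one per input in declaration order
+# (c, dy, dht, dct, it, jt, ft, ot, tanhct) followed by the two outputs
+# (dgate, dct_1); both rows keep everything in FracNZ.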
+
+@op_info_register(basic_lstm_cell_c_state_grad_op_info_v2)
+def _basic_lstm_cell_c_state_grad_tbe_v2():
+    """BasicLSTMCellCStateGradV2 TBE register"""
+    return