From 99ade12fd6f30593bb8bb2e2ba8069c2c8d83f25 Mon Sep 17 00:00:00 2001
From: jjfeing
Date: Tue, 24 Nov 2020 21:17:08 +0800
Subject: [PATCH] dynamic rnn split

dynamic rnn split

dynamic rnn split

add reshape op and check attribute T

check concat attribute T

concat bug about output shape

concat output shape
---
 .../ascend/ascend_backend_optimization.cc     |   7 +-
 .../ascend/ir_fission/concat_fission.cc       |  24 ++-
 .../ir_fission/dynamic_rnn_grad_fission_v2.cc |  23 +-
 .../ascend/ir_fission/splitv_fission.cc       | 203 ++++++++++++++++++
 .../ascend/ir_fission/splitv_fission.h        |  38 ++++
 5 files changed, 284 insertions(+), 11 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.h

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 8d650e7c06..cf5a26baa8 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -105,6 +105,7 @@
 #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h"
 #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h"
 #include "backend/optimizer/ascend/ir_fission/split_fission.h"
+#include "backend/optimizer/ascend/ir_fission/splitv_fission.h"
 #include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
 #include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
 #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
@@ -175,7 +176,10 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
+  ir_fusion_pm->AddPass(std::make_shared<SplitVFission>());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
@@ -289,8 +293,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   ir_fusion_pm->AddPass(std::make_shared());
   AddAscendIRFusionRulesPass(ir_fusion_pm.get());
   AddAscendIRFusionPass(ir_fusion_pm.get());
-  ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
-  ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
       ConfigManager::GetInstance().iter_num() > 1) {
@@ -326,6 +328,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<SplitVFission>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc
index c2c5355fcf..a17387e6bf 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc
@@ -34,8 +34,12 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi
   MS_EXCEPTION_IF_NULL(new_concat);
   new_concat->set_scope(origin_concat_cnode->scope());
   // Set attrs
-  AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat);
-  AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat);
+  if (AnfAlgo::HasNodeAttr(kAttrAxis, origin_concat_cnode)) {
+    AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat);
+  }
+  if (AnfAlgo::HasNodeAttr(kAttrT, origin_concat_cnode)) {
+    AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat);
+  }
   AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(offset)), new_concat);
   AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToLong(offset)), new_concat);
   std::vector<int64_t> dyn_input_sizes{SizeToLong(offset)};
@@ -51,7 +55,12 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi
     MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"
                       << " trace: " << trace::DumpSourceLines(origin_concat_cnode);
   }
-  output_shape[axis] = input_shape[axis] * offset;
+  output_shape[axis] = 0;
+  for (size_t i = begin_index; i < begin_index + offset; ++i) {
+    input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, i - 1);
+    output_shape[axis] += input_shape[axis];
+  }
+  // output_shape[axis] = input_shape[axis] * offset;
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_concat_cnode, 0)}, {output_shape},
                                       new_concat.get());
   return new_concat;
@@ -95,8 +104,13 @@ const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const An
   base_concat->set_scope(new_cnode->scope());
   base_concat->set_abstract(new_cnode->abstract());
   // Set attrs
-  AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat);
-  AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat);
+  if (AnfAlgo::HasNodeAttr(kAttrAxis, new_cnode)) {
+    AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat);
+  }
+  if (AnfAlgo::HasNodeAttr(kAttrT, new_cnode)) {
+    AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat);
+  }
+
   AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat);
   AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat);
   std::vector<int64_t> dyn_input_sizes{SizeToLong(base_concat_inputs.size() - 1)};
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
index 5ab380c715..d887f36613 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
@@ -71,10 +71,10 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
                                         {split_v_output0_shape, split_v_output1_shape}, split_v.get());
     AnfAlgo::SetNodeAttr(kAttrSizeSplits,
-                         MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16),
-                                                        SizeToLong((origin_output3_shape[1] + 15) / 16)}),
+                         MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16 * 16),
+                                                        SizeToLong((origin_output3_shape[1] + 15) / 16 * 16)}),
                          split_v);
-    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(1)), split_v);
+    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(2)), split_v);
     AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);

     basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
@@ -231,7 +231,22 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
     pre_split_outputs = split_outputs;

     lstm_x_concat_input[idx + 1] = split_outputs[0];
-    lstm_gage_concat_input[idx + 1] = basic_lstm_cell_c_state_grad_outputs[0];
+
+    auto basic_lstm_cell_c_state_grad_outputs_0_shape =
+      AnfAlgo::GetOutputInferShape(basic_lstm_cell_c_state_grad_outputs[0], 0);
+    std::vector<size_t> temp_shape;
+    if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == 3) {
+      temp_shape = basic_lstm_cell_c_state_grad_outputs_0_shape;
+    } else {
+      temp_shape = {1, basic_lstm_cell_c_state_grad_outputs_0_shape[0],
+                    basic_lstm_cell_c_state_grad_outputs_0_shape[1]};
+    }
+    std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
+                                             basic_lstm_cell_c_state_grad_outputs[0]};
+    auto reshape = func_graph->NewCNode(reshape_input);
+    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(basic_lstm_cell_c_state_grad_outputs[0], 0)},
+                                        {temp_shape}, reshape.get());
+    lstm_gage_concat_input[idx + 1] = reshape;
   }

   // Create lstm_x_concat
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.cc
new file mode 100644
index 0000000000..fcc0267a54
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.cc
@@ -0,0 +1,203 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/splitv_fission.h"
+#include <memory>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore::opt {
+namespace {
+CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(input_node);
+  std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
+  CNodePtr splitv = func_graph->NewCNode(splitv_inputs);
+  MS_EXCEPTION_IF_NULL(splitv);
+  splitv->set_scope(input_node->scope());
+  return splitv;
+}
+
+CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
+  MS_EXCEPTION_IF_NULL(origin_cnode);
+  if (origin_cnode->inputs().size() < kSplitInputNum) {
+    MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be "
+                      << kSplitInputNum - 1;
+  }
+  return CreateSplitVNode(func_graph, origin_cnode->input(1));
+}
+
+void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int64_t> &size_splits, int64_t split_dim,
+                          int64_t num_split) {
+  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv);
+  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv);
+  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv);
+}
+
+size_t GetSmallSplitSize(const AnfNodePtr &split_node, int64_t split_dim, int64_t num_split) {
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0);
+  if (split_dim < 0) {
+    split_dim += input_shape.size();
+  }
+  if (LongToSize(split_dim) >= input_shape.size()) {
+    MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0";
+  }
+  return input_shape[split_dim] / num_split;
+}
+
+void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int64_t outputs_num,
+                   std::vector<AnfNodePtr> *inputs) {
+  MS_EXCEPTION_IF_NULL(inputs);
+  std::vector<AnfNodePtr> new_splitv_output;
+  CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, LongToSize(outputs_num), &new_splitv_output);
+  inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end());
+}
+
+AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto idx = NewValueNode(SizeToLong(index));
+  MS_EXCEPTION_IF_NULL(idx);
+  auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
+  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
+  idx->set_abstract(abstract_scalar);
+  auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
+  return tuple_getitem;
+}
+
+void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int64_t split_dim, int64_t split_size, int64_t num_split,
+                                std::vector<TypeId> *new_type_ids,
+                                std::vector<std::vector<size_t>> *new_output_shapes) {
+  MS_EXCEPTION_IF_NULL(new_type_ids);
+  MS_EXCEPTION_IF_NULL(new_output_shapes);
+  auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
+  if (split_dim < 0) {
+    split_dim += output_shape.size();
+  }
+  output_shape[split_dim] = split_size;
+  TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
+  for (int64_t i = 0; i < num_split; ++i) {
+    new_type_ids->emplace_back(type_id);
+    new_output_shapes->emplace_back(output_shape);
+  }
+}
+
+void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
+                                     const std::vector<AnfNodePtr> &base_splitv_outputs,
+                                     const std::vector<int64_t> &size_splits_base, int64_t split_dim,
+                                     int64_t num_split) {
+  SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
+  auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
+  TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
+  std::vector<TypeId> base_type_ids(num_split, type_id);
+  std::vector<std::vector<size_t>> base_output_shapes_base;
+  if (split_dim < 0) {
+    split_dim += output_shape.size();
+  }
+  for (int64_t i = 0; i < num_split; ++i) {
+    output_shape[split_dim] = size_splits_base[i];
+    base_output_shapes_base.emplace_back(output_shape);
+    AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get());
+  }
+  AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
+}
+
+AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto split_dim = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrSplitDim);
+  CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode);
+
+  // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs.
+  auto small_split_size = SizeToLong(GetSmallSplitSize(cnode, split_dim, num_split));
+  std::vector<int64_t> size_splits_new(divisor, small_split_size);
+  // Create new output shape and new output type id for each new Splitv node which has full inputs.
+  std::vector<TypeId> new_type_ids;
+  std::vector<std::vector<size_t>> new_output_shapes;
+  CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes);
+
+  // Create make_tuple input to create a make_tuple for replacing the old Split node.
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  // Start to divide the outputs of Split.
+  std::vector<int64_t> size_splits_base;
+  std::vector<AnfNodePtr> base_splitv_outputs;
+  const auto base_split_size = divisor * small_split_size;
+  int64_t nodes_num = 0;
+  int64_t cur_output_index = 0;
+  while (num_split - cur_output_index > divisor) {
+    auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
+    base_splitv_outputs.push_back(tuple_getitem);
+    CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
+    SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
+    AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
+    AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
+    cur_output_index += divisor;
+    size_splits_base.emplace_back(base_split_size);
+    nodes_num++;
+  }
+  if (cur_output_index < num_split) {
+    auto last_node_num_split = num_split - cur_output_index;
+    if (last_node_num_split > 1) {
+      auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
+      base_splitv_outputs.push_back(tuple_getitem);
+      CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
+      std::vector<int64_t> size_splits_new_last(last_node_num_split, small_split_size);
+      SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
+      // Create new output shape and new output type id for the last Splitv node
+      std::vector<TypeId> last_new_type_ids;
+      std::vector<std::vector<size_t>> last_new_output_shapes;
+      CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids,
+                                 &last_new_output_shapes);
+      AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get());
+      AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
+      size_splits_base.emplace_back(last_node_num_split * small_split_size);
+    } else {
+      auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
+      base_splitv_outputs.push_back(tuple_getitem);
+      make_tuple_inputs.emplace_back(tuple_getitem);
+      size_splits_base.emplace_back(small_split_size);
+    }
+    nodes_num++;
+  }
+  // Set attr and abstract for the base splitv
+  SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, base_splitv_outputs, size_splits_base, split_dim, nodes_num);
+  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  return make_tuple;
+}
+}  // namespace
+
+const BaseRef SplitVFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  auto split_prim = std::make_shared<Primitive>(kSplitVOpName);
+  return VectorRef({split_prim, Xs});
+}
+
+const AnfNodePtr SplitVFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                        const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  // Check output num
+  if (!AnfAlgo::HasNodeAttr(kAttrNumSplit, cnode)) {
+    return nullptr;
+  }
+  auto num_split = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrNumSplit);
+  if (num_split <= outputs_divisor_) {
+    return nullptr;
+  }
+  return DoFission(func_graph, cnode, num_split, outputs_divisor_);
+}
+}  // namespace mindspore::opt
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.h
new file mode 100644
index 0000000000..3998e1ca6b
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLITV_FISSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLITV_FISSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class SplitVFission : public PatternProcessPass {
+  const int kSplitOutputsDivisor = 63;
+
+ public:
+  explicit SplitVFission(bool multigraph = true)
+      : PatternProcessPass("splitv_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {}
+  ~SplitVFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  int64_t outputs_divisor_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLITV_FISSION_H_