dynamic rnn split dynamic rnn split add reshape op and check attribute T check concat attribute T concat bug about output shape concat output shapetags/v1.1.0
| @@ -105,6 +105,7 @@ | |||||
| #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h" | #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h" | #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/split_fission.h" | #include "backend/optimizer/ascend/ir_fission/split_fission.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/splitv_fission.h" | |||||
| #include "backend/optimizer/ascend/format_type/modify_ops_attrs.h" | #include "backend/optimizer/ascend/format_type/modify_ops_attrs.h" | ||||
| #include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h" | #include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h" | ||||
| #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" | #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" | ||||
| @@ -175,7 +176,10 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<DereluFusion>()); | ir_fusion_pm->AddPass(std::make_shared<DereluFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<SplitFission>()); | ir_fusion_pm->AddPass(std::make_shared<SplitFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<SplitVFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>()); | ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<PackFission>()); | ir_fusion_pm->AddPass(std::make_shared<PackFission>()); | ||||
| @@ -289,8 +293,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||||
| ir_fusion_pm->AddPass(std::make_shared<DynamicGRUV2GradFission>()); | ir_fusion_pm->AddPass(std::make_shared<DynamicGRUV2GradFission>()); | ||||
| AddAscendIRFusionRulesPass(ir_fusion_pm.get()); | AddAscendIRFusionRulesPass(ir_fusion_pm.get()); | ||||
| AddAscendIRFusionPass(ir_fusion_pm.get()); | AddAscendIRFusionPass(ir_fusion_pm.get()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>()); | |||||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && | if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && | ||||
| ConfigManager::GetInstance().iter_num() > 1) { | ConfigManager::GetInstance().iter_num() > 1) { | ||||
| @@ -326,6 +328,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | auto optimizer = std::make_shared<GraphOptimizer>(); | ||||
| auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); | auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); | ||||
| ir_fusion_pm->AddPass(std::make_shared<SplitFission>()); | ir_fusion_pm->AddPass(std::make_shared<SplitFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<SplitVFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); | ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); | ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>()); | ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>()); | ||||
| @@ -34,8 +34,12 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi | |||||
| MS_EXCEPTION_IF_NULL(new_concat); | MS_EXCEPTION_IF_NULL(new_concat); | ||||
| new_concat->set_scope(origin_concat_cnode->scope()); | new_concat->set_scope(origin_concat_cnode->scope()); | ||||
| // Set attrs | // Set attrs | ||||
| AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat); | |||||
| AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAxis, origin_concat_cnode)) { | |||||
| AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat); | |||||
| } | |||||
| if (AnfAlgo::HasNodeAttr(kAttrT, origin_concat_cnode)) { | |||||
| AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat); | |||||
| } | |||||
| AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(offset)), new_concat); | AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(offset)), new_concat); | ||||
| AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToLong(offset)), new_concat); | AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToLong(offset)), new_concat); | ||||
| std::vector<int64_t> dyn_input_sizes{SizeToLong(offset)}; | std::vector<int64_t> dyn_input_sizes{SizeToLong(offset)}; | ||||
| @@ -51,7 +55,12 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi | |||||
| MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range" | MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range" | ||||
| << " trace: " << trace::DumpSourceLines(origin_concat_cnode); | << " trace: " << trace::DumpSourceLines(origin_concat_cnode); | ||||
| } | } | ||||
| output_shape[axis] = input_shape[axis] * offset; | |||||
| output_shape[axis] = 0; | |||||
| for (size_t i = begin_index; i < begin_index + offset; ++i) { | |||||
| input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, i - 1); | |||||
| output_shape[axis] += input_shape[axis]; | |||||
| } | |||||
| // output_shape[axis] = input_shape[axis] * offset; | |||||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_concat_cnode, 0)}, {output_shape}, | AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_concat_cnode, 0)}, {output_shape}, | ||||
| new_concat.get()); | new_concat.get()); | ||||
| return new_concat; | return new_concat; | ||||
| @@ -95,8 +104,13 @@ const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const An | |||||
| base_concat->set_scope(new_cnode->scope()); | base_concat->set_scope(new_cnode->scope()); | ||||
| base_concat->set_abstract(new_cnode->abstract()); | base_concat->set_abstract(new_cnode->abstract()); | ||||
| // Set attrs | // Set attrs | ||||
| AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat); | |||||
| AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAxis, new_cnode)) { | |||||
| AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat); | |||||
| } | |||||
| if (AnfAlgo::HasNodeAttr(kAttrT, new_cnode)) { | |||||
| AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat); | |||||
| } | |||||
| AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat); | AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat); | ||||
| AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat); | AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat); | ||||
| std::vector<int64_t> dyn_input_sizes{SizeToLong(base_concat_inputs.size() - 1)}; | std::vector<int64_t> dyn_input_sizes{SizeToLong(base_concat_inputs.size() - 1)}; | ||||
| @@ -71,10 +71,10 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn | |||||
| {split_v_output0_shape, split_v_output1_shape}, split_v.get()); | {split_v_output0_shape, split_v_output1_shape}, split_v.get()); | ||||
| AnfAlgo::SetNodeAttr(kAttrSizeSplits, | AnfAlgo::SetNodeAttr(kAttrSizeSplits, | ||||
| MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16), | |||||
| SizeToLong((origin_output3_shape[1] + 15) / 16)}), | |||||
| MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16 * 16), | |||||
| SizeToLong((origin_output3_shape[1] + 15) / 16 * 16)}), | |||||
| split_v); | split_v); | ||||
| AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(1)), split_v); | |||||
| AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(2)), split_v); | |||||
| AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v); | AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v); | ||||
| basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad); | basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad); | ||||
| @@ -231,7 +231,22 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr & | |||||
| pre_split_outputs = split_outputs; | pre_split_outputs = split_outputs; | ||||
| lstm_x_concat_input[idx + 1] = split_outputs[0]; | lstm_x_concat_input[idx + 1] = split_outputs[0]; | ||||
| lstm_gage_concat_input[idx + 1] = basic_lstm_cell_c_state_grad_outputs[0]; | |||||
| auto basic_lstm_cell_c_state_grad_outputs_0_shape = | |||||
| AnfAlgo::GetOutputInferShape(basic_lstm_cell_c_state_grad_outputs[0], 0); | |||||
| std::vector<size_t> temp_shape; | |||||
| if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == 3) { | |||||
| temp_shape = basic_lstm_cell_c_state_grad_outputs_0_shape; | |||||
| } else { | |||||
| temp_shape = {1, basic_lstm_cell_c_state_grad_outputs_0_shape[0], | |||||
| basic_lstm_cell_c_state_grad_outputs_0_shape[1]}; | |||||
| } | |||||
| std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), | |||||
| basic_lstm_cell_c_state_grad_outputs[0]}; | |||||
| auto reshape = func_graph->NewCNode(reshape_input); | |||||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(basic_lstm_cell_c_state_grad_outputs[0], 0)}, | |||||
| {temp_shape}, reshape.get()); | |||||
| lstm_gage_concat_input[idx + 1] = reshape; | |||||
| } | } | ||||
| // Create lstm_x_concat | // Create lstm_x_concat | ||||
| @@ -0,0 +1,203 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/ascend/ir_fission/splitv_fission.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| namespace mindspore::opt { | |||||
| namespace { | |||||
| CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node}; | |||||
| CNodePtr splitv = func_graph->NewCNode(splitv_inputs); | |||||
| MS_EXCEPTION_IF_NULL(splitv); | |||||
| splitv->set_scope(input_node->scope()); | |||||
| return splitv; | |||||
| } | |||||
| CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { | |||||
| MS_EXCEPTION_IF_NULL(origin_cnode); | |||||
| if (origin_cnode->inputs().size() < kSplitInputNum) { | |||||
| MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " | |||||
| << kSplitInputNum - 1; | |||||
| } | |||||
| return CreateSplitVNode(func_graph, origin_cnode->input(1)); | |||||
| } | |||||
| void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int64_t> &size_splits, int64_t split_dim, | |||||
| int64_t num_split) { | |||||
| AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv); | |||||
| AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv); | |||||
| AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv); | |||||
| } | |||||
| size_t GetSmallSplitSize(const AnfNodePtr &split_node, int64_t split_dim, int64_t num_split) { | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0); | |||||
| if (split_dim < 0) { | |||||
| split_dim += input_shape.size(); | |||||
| } | |||||
| if (LongToSize(split_dim) >= input_shape.size()) { | |||||
| MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0"; | |||||
| } | |||||
| return input_shape[split_dim] / num_split; | |||||
| } | |||||
| void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int64_t outputs_num, | |||||
| std::vector<AnfNodePtr> *inputs) { | |||||
| MS_EXCEPTION_IF_NULL(inputs); | |||||
| std::vector<AnfNodePtr> new_splitv_output; | |||||
| CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, LongToSize(outputs_num), &new_splitv_output); | |||||
| inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end()); | |||||
| } | |||||
| AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto idx = NewValueNode(SizeToLong(index)); | |||||
| MS_EXCEPTION_IF_NULL(idx); | |||||
| auto imm = std::make_shared<Int64Imm>(SizeToLong(index)); | |||||
| auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); | |||||
| idx->set_abstract(abstract_scalar); | |||||
| auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx}); | |||||
| return tuple_getitem; | |||||
| } | |||||
| void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int64_t split_dim, int64_t split_size, int64_t num_split, | |||||
| std::vector<TypeId> *new_type_ids, | |||||
| std::vector<std::vector<size_t>> *new_output_shapes) { | |||||
| MS_EXCEPTION_IF_NULL(new_type_ids); | |||||
| MS_EXCEPTION_IF_NULL(new_output_shapes); | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); | |||||
| if (split_dim < 0) { | |||||
| split_dim += output_shape.size(); | |||||
| } | |||||
| output_shape[split_dim] = split_size; | |||||
| TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); | |||||
| for (int64_t i = 0; i < num_split; ++i) { | |||||
| new_type_ids->emplace_back(type_id); | |||||
| new_output_shapes->emplace_back(output_shape); | |||||
| } | |||||
| } | |||||
| void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv, | |||||
| const std::vector<AnfNodePtr> &base_splitv_outputs, | |||||
| const std::vector<int64_t> &size_splits_base, int64_t split_dim, | |||||
| int64_t num_split) { | |||||
| SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split); | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); | |||||
| TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); | |||||
| std::vector<TypeId> base_type_ids(num_split, type_id); | |||||
| std::vector<std::vector<size_t>> base_output_shapes_base; | |||||
| if (split_dim < 0) { | |||||
| split_dim += output_shape.size(); | |||||
| } | |||||
| for (int64_t i = 0; i < num_split; ++i) { | |||||
| output_shape[split_dim] = size_splits_base[i]; | |||||
| base_output_shapes_base.emplace_back(output_shape); | |||||
| AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get()); | |||||
| } | |||||
| AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get()); | |||||
| } | |||||
| AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto split_dim = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrSplitDim); | |||||
| CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode); | |||||
| // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs. | |||||
| auto small_split_size = SizeToLong(GetSmallSplitSize(cnode, split_dim, num_split)); | |||||
| std::vector<int64_t> size_splits_new(divisor, small_split_size); | |||||
| // Create new output shape and new output type id for each new Splitv node which has full inputs. | |||||
| std::vector<TypeId> new_type_ids; | |||||
| std::vector<std::vector<size_t>> new_output_shapes; | |||||
| CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes); | |||||
| // Create make_tuple input to create a make_tuple for replacing the old Split node. | |||||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||||
| // Start to divide the outputs of Split. | |||||
| std::vector<int64_t> size_splits_base; | |||||
| std::vector<AnfNodePtr> base_splitv_outputs; | |||||
| const auto base_split_size = divisor * small_split_size; | |||||
| int64_t nodes_num = 0; | |||||
| int64_t cur_output_index = 0; | |||||
| while (num_split - cur_output_index > divisor) { | |||||
| auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num); | |||||
| base_splitv_outputs.push_back(tuple_getitem); | |||||
| CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem); | |||||
| SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor); | |||||
| AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get()); | |||||
| AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs); | |||||
| cur_output_index += divisor; | |||||
| size_splits_base.emplace_back(base_split_size); | |||||
| nodes_num++; | |||||
| } | |||||
| if (cur_output_index < num_split) { | |||||
| auto last_node_num_split = num_split - cur_output_index; | |||||
| if (last_node_num_split > 1) { | |||||
| auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num); | |||||
| base_splitv_outputs.push_back(tuple_getitem); | |||||
| CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem); | |||||
| std::vector<int64_t> size_splits_new_last(last_node_num_split, small_split_size); | |||||
| SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); | |||||
| // Create new output shape and new output type id for the last Splitv node | |||||
| std::vector<TypeId> last_new_type_ids; | |||||
| std::vector<std::vector<size_t>> last_new_output_shapes; | |||||
| CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids, | |||||
| &last_new_output_shapes); | |||||
| AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get()); | |||||
| AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs); | |||||
| size_splits_base.emplace_back(last_node_num_split * small_split_size); | |||||
| } else { | |||||
| auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num); | |||||
| base_splitv_outputs.push_back(tuple_getitem); | |||||
| make_tuple_inputs.emplace_back(tuple_getitem); | |||||
| size_splits_base.emplace_back(small_split_size); | |||||
| } | |||||
| nodes_num++; | |||||
| } | |||||
| // Set Attr and abstract for the base splitv | |||||
| SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, base_splitv_outputs, size_splits_base, split_dim, nodes_num); | |||||
| AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||||
| return make_tuple; | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef SplitVFission::DefinePattern() const { | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| auto split_prim = std::make_shared<Primitive>(kSplitVOpName); | |||||
| return VectorRef({split_prim, Xs}); | |||||
| } | |||||
| const AnfNodePtr SplitVFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (AnfAlgo::IsDynamicShape(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| // Check output num | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrNumSplit, cnode)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto num_split = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrNumSplit); | |||||
| if (num_split <= outputs_divisor_) { | |||||
| return nullptr; | |||||
| } | |||||
| return DoFission(func_graph, cnode, num_split, outputs_divisor_); | |||||
| } | |||||
| } // namespace mindspore::opt | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLITV_FISSION_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLITV_FISSION_H_ | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class SplitVFission : public PatternProcessPass { | |||||
| const int kSplitOutputsDivisor = 63; | |||||
| public: | |||||
| explicit SplitVFission(bool multigraph = true) | |||||
| : PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {} | |||||
| ~SplitVFission() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| int64_t outputs_divisor_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLITV_FISSION_H_ | |||||