@@ -1 +1 @@
-Subproject commit 20ecddee01cd07d0945240672597d7a36499e537
+Subproject commit c63b2e6f7e7704f18b217e42c8c5c0b95e04b9fb
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,3 +15,4 @@
 """init"""
 from .splitter import split_with_json
 from .expander import get_op_expander
+from .parallel_estimate import estimate_calculation_amount, estimate_ops
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,3 +16,4 @@
 from .graph_split import split
 from .model_builder import GraphBuilder, load_composite
+from .graph_parallel import parallel_estimate
@@ -0,0 +1,153 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""Cost model for parallel fusion"""
+from .model import PrimLib
+
+
+class ParalGain:
+    """Estimated gain of a parallel fusion plan."""
+
+    def __init__(self, fusion_type, bottleneck, gain, block_assign):
+        self.fusion_type = fusion_type
+        self.bottleneck = bottleneck
+        self.gain = gain
+        self.block_assign = block_assign
+
+
+class ScheduleAnalyzer:
+    """schedule analyzer"""
+    WARP_SIZE = 32
+    MAX_SM = 80  # Volta
+    MAX_NUM_THREADS = 1024
+    MAX_BLOCK = 256
+
+    def __init__(self, graph):
+        self.graph = graph
+        self.block_num = 0
+        self.block_weight = 0
+        _, outputs = graph.deduce_parameters()
+        self.ops = graph.ops
+        self.dom_op = [out.op for out in outputs]
+
+    def prod(self, shape):
+        """product of all dimensions in shape"""
+        res = shape[0]
+        for i in range(1, len(shape)):
+            res = res * shape[i]
+        return res
+
+    def _cal_weight(self, ops):
+        weight = 0
+        for op in ops:
+            weight += self.prod(op.output.shape) * \
+                PrimLib.dtype_bytes(op.output.dtype)
+        return weight
+
+    def injective_analyze(self):
+        """analyze injective case"""
+        const_size = max(self.prod(op.output.shape) for op in self.dom_op)
+        const_size = (const_size + self.MAX_NUM_THREADS -
+                      1) // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS
+        total_weight = self._cal_weight(self.ops)
+        total_block = (const_size + self.MAX_NUM_THREADS -
+                       1) // self.MAX_NUM_THREADS
+        need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS
+        if need_block_split:
+            self.block_num = self.MAX_BLOCK
+            waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK
+            self.block_weight = total_weight // total_block * waves
+        else:
+            self.block_num = total_block
+            self.block_weight = total_weight // self.block_num
+
+    def reduce_analyze(self):
+        """analyze reduce case"""
+        thread_x, thread_y = 32, 32
+        reduce_op = None
+        for op in self.ops:
+            if PrimLib.iter_type(op) == PrimLib.REDUCE:
+                if reduce_op:
+                    raise RuntimeError(
+                        "Multiple reduce ops in one graph are not supported yet.")
+                reduce_op = op
+        if not reduce_op:
+            raise RuntimeError("No reduce op found for reduce analysis!")
+        shape = reduce_op.inputs[0].shape
+        reduce_axis = reduce_op.attrs['reduce_axis']
+        total_space = self.prod(shape)
+        red_space = shape[reduce_axis[0]]
+        for i in range(1, len(reduce_axis)):
+            red_space *= shape[reduce_axis[i]]
+        dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype)
+        weight = self._cal_weight(self.ops)  # reduce + injective
+        block_x = (total_space // red_space + thread_y - 1) // thread_y
+        block_w = (weight + block_x - 1) // block_x
+        waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK
+        self.block_num = min(self.MAX_BLOCK, block_x)
+        all_reduce = 10  # 1 reduce init + 3 sync + 5 bin + 1 write
+        self.block_weight = (block_w + all_reduce *
+                             dtype_size * thread_x * thread_y) * waves
+
+    def default_analyze(self):
+        """analyze default case"""
+        def _cal_default_space(op):
+            space = self.prod(op.output.shape)
+            for t in op.inputs:
+                size = self.prod(t.shape)
+                if size > space:
+                    space = size
+            return space
+
+        space = max(_cal_default_space(op) for op in self.dom_op)
+        # each SM needs at least 4 warps
+        block = (space + (self.WARP_SIZE * 4) - 1) // (self.WARP_SIZE * 4)
+        self.block_num = min(self.MAX_BLOCK, block)
+        self.block_weight = self._cal_weight(self.ops) // self.block_num
+
+    def analyze(self):
+        """analyze ops"""
+        def _ops_type(ops, dom_op):
+            have_reduce = any(
+                PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops)
+            if have_reduce:
+                # Return the REDUCE iter type itself; a bare True here is
+                # fragile because it compares equal to 1 in the checks below.
+                return PrimLib.REDUCE
+            return PrimLib.iter_type(dom_op[0])
+
+        dom_type = _ops_type(self.ops, self.dom_op)
+        if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
+            self.injective_analyze()
+        elif dom_type == PrimLib.REDUCE:
+            self.reduce_analyze()
+        else:
+            self.default_analyze()
+
+
+def block_parallel_estimate(graphs):
+    """estimate block parallel gain"""
+    sum_block, max_weight, sum_weight, blocks = 0, 0, 0, []
+    for g in graphs:
+        s = ScheduleAnalyzer(g)
+        s.analyze()
+        sum_block += s.block_num
+        if s.block_weight > max_weight:
+            max_weight = s.block_weight
+        sum_weight += s.block_weight
+        blocks.append(s.block_num)
+    if sum_block > ScheduleAnalyzer.MAX_SM * 32:
+        return ParalGain("none", sum_weight, 0, [])
+    return ParalGain("block_fusion", max_weight, sum_weight - max_weight, blocks)
+
+
+def parallel_estimate(graphs):
+    """estimate parallel gain for the given graphs"""
+    return block_parallel_estimate(graphs)
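
To make the arithmetic above concrete, here is a minimal, self-contained replay of `injective_analyze` for one hypothetical elementwise op (output shape `[1024, 2048]`, float32; the shape, dtype and results are illustrative only):

```python
# Hand-rolled replay of injective_analyze() for one hypothetical op:
# output shape [1024, 2048], float32 (4 bytes). Constants mirror the class.
MAX_NUM_THREADS = 1024
MAX_BLOCK = 256

size = 1024 * 2048                # prod(op.output.shape) = 2_097_152
const_size = (size + MAX_NUM_THREADS - 1) // MAX_NUM_THREADS * MAX_NUM_THREADS
total_weight = size * 4           # _cal_weight: elements * dtype bytes
total_block = (const_size + MAX_NUM_THREADS - 1) // MAX_NUM_THREADS  # 2048

if const_size > MAX_BLOCK * MAX_NUM_THREADS:        # 2_097_152 > 262_144: split
    block_num = MAX_BLOCK                           # capped at 256 blocks
    waves = (total_block + MAX_BLOCK - 1) // MAX_BLOCK   # 8 launch waves
    block_weight = total_weight // total_block * waves   # 4096 * 8 = 32768
else:
    block_num, block_weight = total_block, total_weight // total_block

print(block_num, block_weight)    # 256 32768
```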
@@ -0,0 +1,49 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""estimate parallel case"""
+import json
+import json.decoder as jd
+import traceback
+from mindspore import log as logger
+from . import model
+
+
+def estimate_ops(json_str: str):
+    """Call costmodel to estimate ops."""
+    try:
+        json_obj = json.loads(json_str)
+        graph_descs = json_obj["graph_desc"]
+        graphs = []
+        for gd in graph_descs:
+            graphs.append(model.load_composite(gd).graph)
+        estimation = model.parallel_estimate(graphs)
+        if estimation.fusion_type == "block_fusion" and estimation.gain > 0:
+            res = (estimation.block_assign, estimation.gain)
+        else:
+            res = ([0] * len(graphs), 0)
+        return res
+    except jd.JSONDecodeError:
+        logger.error(traceback.format_exc())
+        return None
+
+
+def estimate_calculation_amount(json_str: str):
+    """Call costmodel to estimate the calculation amount of an op."""
+    try:
+        graph_desc = json.loads(json_str)
+        comp = model.load_composite(graph_desc)
+        estimation = model.parallel_estimate([comp.graph])
+        return estimation.bottleneck
+    except jd.JSONDecodeError:
+        logger.error(traceback.format_exc())
+        return None
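
For orientation, a hedged sketch of how these two entry points are fed; the real `graph_desc` payloads come from `AkgKernelJsonGenerator` on the C++ side, so the description below is a placeholder, not the actual schema:

```python
import json

# Hypothetical composite description; a real one carries full op/input/output descs.
fake_desc = {"op": "Fused_Example", "input_desc": [], "output_desc": []}

# estimate_ops expects {"graph_desc": [desc, ...]} and returns either
# (block_assign, gain) for a profitable block fusion, or ([0, ...], 0) otherwise.
request = json.dumps({"graph_desc": [fake_desc, fake_desc]})
# result = estimate_ops(request)

# estimate_calculation_amount takes a single description and returns the
# bottleneck weight of that graph alone.
single = json.dumps(fake_desc)
# amount = estimate_calculation_amount(single)
```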
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -120,7 +120,7 @@ class OpInfoExtractor {
       }
     }
     if (op_attr->type().empty()) {
-      MS_LOG(DEBUG) << "Unknow type, ignore attr " << name;
+      MS_LOG(DEBUG) << "Unknown type, ignore attr " << name;
       continue;
     }
     op_info->add_attrs_ptr(op_attr);
@@ -174,7 +174,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
   // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
   auto inputs_ptr = op_info->inputs_ptr();
   if (inputs_ptr.empty()) {
-    MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
+    MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] info has no input info";
     return false;
   }
@@ -184,7 +184,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
   for (size_t i = 0; i < inputs_ptr.size(); i++) {
     auto input_ptr = inputs_ptr[i];
     if (input_ptr == nullptr) {
-      MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist input[" << i << "] is nullptr";
+      MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] input[" << i << "] is nullptr";
       return false;
     }
@@ -204,7 +204,8 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
     input_desc_json[kJsonKeyName] = input_ptr->name();
     input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
     auto input_shape = this->GetInputShape(anf_node, real_input_index);
-    if (AnfAlgo::IsNodeInGraphKernel(anf_node) && GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
+    if (dump_option_.extract_opinfo_from_anfnode &&
+        GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
       MS_LOG(DEBUG) << "Take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
                     << "] as const tensor, shape: [" << Vector2Str(input_shape)
                     << "], value: " << input_desc_json[kJsonKeyValue];
@@ -555,6 +556,30 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
   return true;
 }
 
+void AkgKernelJsonGenerator::SetParallelValueToJson(const std::string &processor,
+                                                    const std::map<size_t, size_t> &dim_infos,
+                                                    nlohmann::json *sub_fusion_json) {
+  if (processor == kProcessorCuda) {
+    std::vector<size_t> cnums;
+    std::transform(dim_infos.cbegin(), dim_infos.cend(), std::back_insert_iterator(cnums),
+                   [](const std::pair<size_t, size_t> &dim) { return dim.second; });
+    (*sub_fusion_json)[kJsonKeyCoreNum] = cnums;
+  } else {
+    MS_LOG(EXCEPTION) << "Parallel fusion does not support " << processor << " yet.";
+  }
+}
+
+void AkgKernelJsonGenerator::AddParallelFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json) {
+  nlohmann::json parallel_fusion_json;
+  parallel_fusion_json[kJsonKeyFusionType] = "block_fusion";
+  std::vector<std::vector<std::string>> sgraphs;
+  std::transform(sub_graphs_.cbegin(), sub_graphs_.cend(), std::back_insert_iterator(sgraphs),
+                 [](const std::pair<int, std::vector<std::string>> &sg) { return sg.second; });
+  parallel_fusion_json[kJsonKeySubGraph] = sgraphs;
+  SetParallelValueToJson(processor, dim_infos_, &parallel_fusion_json);
+  (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
+}
+
 bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
                                               const std::vector<AnfNodePtr> &input_list,
                                               const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) {
@@ -581,6 +606,13 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyOutputDesc] =
     CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map);
 
+  auto processor = GetProcessorStr(anf_nodes[0]);
+
+  // Add parallel fusion information.
+  if (!sub_graphs_.empty()) {
+    AddParallelFusionJsonInfo(processor, kernel_json);
+  }
+
   size_t hash_id = std::hash<std::string>()(kernel_json->dump());
   kernel_name_ = "Fused_";
   auto fg = anf_nodes[0]->func_graph();
@@ -601,7 +633,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyId] = GetOpCntInc();
   (*kernel_json)[kJsonKeyOp] = kernel_name_;
   (*kernel_json)[kJsonKeyPlatform] = "AKG";
-  (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
+  (*kernel_json)[kJsonKeyProcess] = processor;
   (*kernel_json)[kJsonKeyComposite] = true;
   (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
@@ -724,6 +756,17 @@ nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNo
         output_shape.push_back(1);
       }
       output_desc_json[kJsonKeyShape] = output_shape;
+      if (auto tcnode = tmp_output.first->cast<CNodePtr>();
+          tcnode && AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) {
+        auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo);
+        if (info.size() != 2) {
+          MS_LOG(EXCEPTION) << "Parallel dim info is invalid!";
+        }
+        sub_graphs_[info[0]].push_back(output_desc_json[kJsonKeyTensorName]);
+        if (dim_infos_.find(info[0]) == dim_infos_.end()) {
+          dim_infos_[info[0]] = info[1];
+        }
+      }
     }
     outputs_json.emplace_back(output_desc_json);
   }
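
Assembled, the extra section the generator emits into the fused-kernel JSON looks roughly like the following, sketched as a Python dict (key names follow `kJsonKeyParallelFusion`, `kJsonKeyFusionType`, `kJsonKeySubGraph` and `kJsonKeyCoreNum` above; tensor names and core counts are invented):

```python
kernel_json_fragment = {
    "parallel_fusion": {
        "fusion_type": "block_fusion",
        # one tensor-name list per sub graph, grouped by parallel_dim_info[0]
        "sub_graph": [["output_0_0", "output_0_1"], ["output_1_0"]],
        # one core number per sub graph (CUDA only), from parallel_dim_info[1]
        "core_num": [64, 32],
    }
}
```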
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -49,6 +49,11 @@ constexpr auto kJsonKeyPtrAddress = "ptr_address";
 constexpr auto kJsonKeyCompositeGraph = "composite_graph";
 constexpr auto kJsonKeyPlatform = "platform";
 constexpr auto kJsonKeyOpFullName = "op_full_name";
+constexpr auto kJsonKeyFusion = "fusion";
+constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
+constexpr auto kJsonKeyFusionType = "fusion_type";
+constexpr auto kJsonKeySubGraph = "sub_graph";
+constexpr auto kJsonKeyCoreNum = "core_num";
 
 constexpr auto kAttrInputNames = "input_names";
@@ -81,6 +86,8 @@ class AkgKernelJsonGenerator {
     input_tensor_idx_.clear();
     address_node_map_.clear();
     output_tensor_idx_ = 0;
+    sub_graphs_.clear();
+    dim_infos_.clear();
   }
   void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
   std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }
@@ -115,6 +122,9 @@ class AkgKernelJsonGenerator {
   std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index);
   void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json);
   OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node);
+  void SetParallelValueToJson(const std::string &processor, const std::map<size_t, size_t> &dim_infos,
+                              nlohmann::json *sub_fusion_json);
+  void AddParallelFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json);
 
   DumpOption dump_option_;
   static int op_cnt_;
@@ -127,6 +137,8 @@ class AkgKernelJsonGenerator {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::map<std::string, AnfNodePtr> address_node_map_;
+  std::map<size_t, std::vector<std::string>> sub_graphs_;
+  std::map<size_t, size_t> dim_infos_;
 };
 }  // namespace kernel
 }  // namespace mindspore
@@ -133,8 +133,10 @@ bool AtomicCleanInsertter::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
     if (reduce_cnt != 1) {
       return false;
     }
+    real_output_num_ = inputs.size() - 1;
   } else if (IsPrimitiveCNode(real_return_node, prim::kPrimReduceSum)) {
     atomic_add_node_ = real_return_node->cast<CNodePtr>();
+    real_output_num_ = 1;
   } else {
     return false;
   }
@@ -200,7 +202,6 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra
   auto retrun_node = sub_graph->get_return()->input(kFirstDataInputIndex);
   if (IsPrimitiveCNode(retrun_node, prim::kPrimMakeTuple)) {
     const auto &outs = retrun_node->cast<CNodePtr>()->inputs();
-    real_output_num_ = outs.size() - 1;
     for (size_t i = 1; i < outs.size(); ++i) {
       if (i != reduce_real_output_index_ + 1) {
         out_node = outs[i];
@@ -209,7 +210,6 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra
       }
     }
   } else {
-    real_output_num_ = 1;
     out_node = atomic_add_node_;  // Use result data itself, and set attr "fake_out" true.
     fake_out = true;
   }
@@ -456,7 +456,7 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs
     }
   }
   for (auto &pair : getitem_user_nodes) {
-    // dirctory to find real user.
+    // Directly find the real users.
    auto real_users = mng->node_users()[pair.first];
    reduce_user_nodes.insert(reduce_user_nodes.end(), real_users.begin(), real_users.end());
  }
@@ -0,0 +1,155 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/depend_formater.h"
+#include <algorithm>
+#include <tuple>
+#include <utility>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool RemoveRedundantDepend(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
+  const auto &users = mng->node_users()[node];
+  std::vector<std::pair<AnfNodePtr, int>> sons;
+  for (const auto &[user, index] : users) {
+    if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
+      sons.emplace_back(user, index);
+      continue;
+    }
+    // For a getitem user, follow through to its first user instead.
+    auto &[getitem_user, getitem_index] = *((mng->node_users()[user]).begin());
+    sons.emplace_back(getitem_user, getitem_index);
+  }
+
+  AnfNodePtrList latter_to_delete;
+  for (const auto &[son, index] : sons) {
+    if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
+      continue;
+    }
+    latter_to_delete.push_back(son);
+  }
+
+  if (latter_to_delete.empty()) {
+    return false;
+  }
+
+  std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
+  if (latter_to_delete.size() == sons.size()) {
+    // Keep one Depend relation and delete the others.
+    ++delete_begin;
+  }
+  for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
+    auto depend_anfnode = *delete_begin;
+    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(depend_cnode);
+    auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
+    mng->Replace(depend_anfnode, depend_prior_node);
+  }
+  return true;
+}
+
+AnfNodePtr FindPatronNode(const FuncGraphPtr &main_graph, const FuncGraphManagerPtr &mng) {
+  AnfNodePtr patron_node;
+  auto return_cnode = main_graph->get_return()->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(return_cnode);
+  auto output_node = return_cnode->input(kFirstDataInputIndex);
+  if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
+    auto output_cnode = output_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(output_cnode);
+    patron_node = output_cnode->input(kFirstDataInputIndex);
+  } else {
+    patron_node = output_node;
+  }
+  return patron_node;
+}
+
+void AddDepends(const AnfNodePtr &stable_node, const AnfNodePtrList &free_nodes, const FuncGraphPtr &main_graph,
+                const FuncGraphManagerPtr &mng) {
+  AnfNodePtr modified_node = stable_node;
+  for (const auto &free_node : free_nodes) {
+    AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), modified_node, free_node};
+    auto depend_cnode = main_graph->NewCNode(d_inputs);
+    depend_cnode->set_abstract(modified_node->abstract());
+    main_graph->AddNode(depend_cnode);
+    modified_node = depend_cnode;
+  }
+
+  if (!free_nodes.empty()) {
+    mng->Replace(stable_node, modified_node);
+  }
+}
+}  // namespace
+
+bool DependFormater::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+
+  // 1. Try to remove redundant depends.
+  bool changed = false;
+  auto nodes = TopoSort(func_graph->get_return());
+  std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) {
+    if (RemoveRedundantDepend(node, mng)) {
+      changed = true;
+    }
+  });
+
+  // Re-toposort the changed graph.
+  if (changed) {
+    nodes = TopoSort(func_graph->get_return());
+  }
+
+  // 2. Move depends to the tail of the graph.
+  AnfNodePtrList old_depends;
+  AnfNodePtrList free_nodes;
+
+  // Find depend nodes and their free (attached) nodes.
+  for (const auto &node : nodes) {
+    if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
+      continue;
+    }
+    old_depends.push_back(node);
+    free_nodes.push_back(node->cast<CNodePtr>()->input(kDependAttachNodeIndex));
+  }
+
+  if (old_depends.empty()) {
+    return changed;
+  }
+
+  // Delete the old depends.
+  for (const auto &depend_anfnode : old_depends) {
+    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(depend_cnode);
+    auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
+    mng->Replace(depend_anfnode, depend_prior_node);
+  }
+
+  // Add new depend nodes at the tail.
+  AnfNodePtr patron_node = FindPatronNode(func_graph, mng);
+  AddDepends(patron_node, free_nodes, func_graph, mng);
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
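
A rough illustration of the rewrite this pass performs, as a toy Python sketch with tuples standing in for CNodes (not real MindSpore IR): every `Depend(real, attach)` in the body collapses to `real`, and the collected attach nodes are re-hung as a Depend chain on the single patron node at the graph tail:

```python
def move_depends_to_tail(depend_pairs, patron):
    """depend_pairs: [(real_input, attach_node), ...] harvested from the body."""
    # Step 1 (delete old depends): each Depend is replaced by its real input,
    # so only the attach ("free") nodes survive here.
    free_nodes = [attach for _, attach in depend_pairs]
    # Step 2 (AddDepends): chain the free nodes onto the patron node.
    tail = patron
    for attach in free_nodes:
        tail = ("Depend", tail, attach)   # Depend(real=tail, attach=attach)
    return tail

print(move_depends_to_tail([("a", "u1"), ("b", "u2")], "out"))
# ('Depend', ('Depend', 'out', 'u1'), 'u2')
```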
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
+
+#include <map>
+#include <memory>
+#include "backend/optimizer/common/pass.h"
+#include "ir/func_graph.h"
+
+namespace mindspore {
+namespace opt {
+class DependFormater : public Pass {
+ public:
+  DependFormater() : Pass("depend_formater") {}
+  ~DependFormater() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+};
+using DependFormaterPtr = std::shared_ptr<DependFormater>;
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
@@ -274,7 +274,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
   MS_EXCEPTION_IF_NULL(inputs_ptr);
   auto nodes = TopoSort(fg->get_return());
 
-  std::map<ValuePtr, AnfNodePtrList> vmap;
+  OrderedMap<ValuePtr, AnfNodePtrList> vmap;
   for (const auto &node : nodes) {
     if (!node->isa<CNode>()) {
       continue;
@@ -590,7 +590,7 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
     op_nodes = nodes;
   } else {
     // When there are basic and composite ops, the composite ops should be inline to the basic ones' graph,
-    // so a new graph generation should be done (beacuse they may in the main graph!).
+    // so a new graph generation should be done (because they may be in the main graph!).
     // If address_node_map is wanted, we should map the new node in new graph to the old nodes. But... not support now.
     MS_LOG(EXCEPTION) << "No support mixed with basic and composite ops now!";
   }
@@ -1016,5 +1016,16 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
   func_graph->AddNode(cnode);
   return cnode;
 }
 
+// Clone the primitive so that later attr edits do not leak to nodes sharing it.
+void MakeCNodeSafeForAttr(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr) {
+    return;
+  }
+  AnfNodePtrList new_inputs = {NewValueNode(AnfAlgo::GetCNodePrimitive(cnode)->Clone())};
+  auto inputs = cnode->inputs();
+  new_inputs.insert(new_inputs.end(), inputs.begin() + 1, inputs.end());
+  cnode->set_inputs(new_inputs);
+}
 }  // namespace opt
 }  // namespace mindspore
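
The hazard `MakeCNodeSafeForAttr` guards against, sketched in Python with stand-in classes (not MindSpore IR): CNodes that share one primitive object also share attribute edits, so a node gets its own clone before any attr is set:

```python
import copy

class Prim:                      # stand-in for a shared Primitive
    def __init__(self):
        self.attrs = {}

shared = Prim()
node_a, node_b = {"prim": shared}, {"prim": shared}

# Without a clone, an attr set "on node_a" silently leaks to node_b.
node_a["prim"].attrs["parallel_dim_info"] = [0, 64]
assert "parallel_dim_info" in node_b["prim"].attrs

# The fix: rebuild node_a's inputs around a cloned primitive first.
node_a["prim"] = copy.deepcopy(node_a["prim"])
node_a["prim"].attrs["extra"] = True
assert "extra" not in node_b["prim"].attrs
```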
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -42,6 +42,8 @@ using kernel::DumpOption;
 constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
 constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
 constexpr auto kGraphKernelModule = "mindspore._extends.graph_kernel";
+constexpr auto kGraphKernelEstimateOps = "estimate_ops";
+constexpr auto kGraphKernelGetNodeCalAmount = "estimate_calculation_amount";
 constexpr auto kGraphKernelSplitFunc = "split_with_json";
 constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
 constexpr auto kJsonKeyMultiGraph = "multi_graph";
@@ -88,6 +90,7 @@ ShapeVector GetShape(const AnfNodePtr &node);
 std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
 CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
+void MakeCNodeSafeForAttr(const AnfNodePtr &node);
 
 template <typename T>
 ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/parallel_cost_model.h"
+
+#include <algorithm>
+
+#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "pipeline/jit/parse/python_adapter.h"
+
+namespace mindspore {
+namespace opt {
+std::string CommonDimInfo::ToString() {
+  std::ostringstream buffer;
+  buffer << "Dim(" << dim_info_ << ")";
+  return buffer.str();
+}
+
+int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
+  nlohmann::json json_desc;
+  AnfNodePtrList nodes = {node};
+  DumpOption dump_option;
+  if (!AnfToJsonDesc(nodes, dump_option, &json_desc)) {
+    MS_LOG(EXCEPTION) << "Collect json desc failed.";
+  }
+
+  auto json_desc_str = json_desc.dump();
+  auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelGetNodeCalAmount, json_desc_str);
+  if (py::isinstance<py::none>(ret)) {
+    MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelGetNodeCalAmount << "] return invalid result. input json:\n"
+                      << json_desc_str;
+  }
+  return py::cast<int>(ret);
+}
+
+std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
+  nlohmann::json json_desc;
+  std::vector<AnfNodePtrList> graphs;
+  std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs),
+                 [](const AnfNodePtr &node) -> AnfNodePtrList { return {node}; });
+  DumpOption dump_option;
+  if (!AnfToJsonDesc(graphs, dump_option, &json_desc)) {
+    MS_LOG(EXCEPTION) << "Collect json desc failed.";
+  }
+
+  auto json_desc_str = json_desc.dump();
+  auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelEstimateOps, json_desc_str);
+  if (py::isinstance<py::none>(ret)) {
+    MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelEstimateOps << "] return invalid result. input json:\n"
+                      << json_desc_str;
+  }
+
+  // Check the reply is a tuple before casting it.
+  if (!py::isinstance<py::tuple>(ret)) {
+    MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!";
+  }
+  py::tuple ret_tuple = py::cast<py::tuple>(ret);
+  if (ret_tuple.size() != 2) {
+    MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!";
+  }
+
+  std::vector<DimInfoPtr> dim_infos;
+  py::list dim_list = py::cast<py::list>(ret_tuple[0]);
+  for (size_t i = 0; i < dim_list.size(); ++i) {
+    dim_infos.push_back(std::make_shared<CommonDimInfo>(py::cast<int>(dim_list[i])));
+  }
+  int benefit = py::cast<int>(ret_tuple[1]);
+
+  return std::make_tuple(dim_infos, benefit);
+}
+
+ParallelCostModelPtr ParallelCostModelWarehouse::GetParallelCostModel(const std::string &target) {
+  if (target != kGPUDevice) {
+    MS_LOG(EXCEPTION) << "Parallel cost model only supports " << kGPUDevice << " now.";
+  }
+  return cost_model_;
+}
+}  // namespace opt
+}  // namespace mindspore
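
A note on the cross-language contract `CalFuseInfo` assumes: the Python `estimate_ops` must hand back a 2-tuple of (per-graph block counts, total gain), or the checks above raise. A sketch of those checks in Python, with made-up numbers:

```python
def parse_cost_model_reply(ret):
    """Mimics CalFuseInfo's validation of the Python reply (sketch only)."""
    if ret is None:                      # py::none -> "invalid result" error
        raise RuntimeError("cost model returned invalid result")
    if not isinstance(ret, tuple) or len(ret) != 2:
        raise RuntimeError("expected a tuple with two elements")
    dim_list, benefit = ret              # one dim per graph, plus the gain
    return [int(d) for d in dim_list], int(benefit)

print(parse_cost_model_reply(([64, 32, 32], 1024)))  # ([64, 32, 32], 1024)
```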
@@ -0,0 +1,82 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "base/base.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/session/kernel_graph.h"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace opt {
+class DimInfo {
+ public:
+  DimInfo() = default;
+  virtual ~DimInfo() = default;
+  virtual std::string ToString() = 0;
+};
+
+class CommonDimInfo : public DimInfo {
+ public:
+  explicit CommonDimInfo(size_t dim) : dim_info_(dim) {}
+  ~CommonDimInfo() override = default;
+  void set_dim_info(size_t d) { dim_info_ = d; }
+  size_t dim_info() const { return dim_info_; }
+  std::string ToString() override;
+
+ private:
+  size_t dim_info_;
+};
+
+using DimInfoPtr = std::shared_ptr<DimInfo>;
+using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>;
+
+class ParallelCostModel {
+ public:
+  ParallelCostModel() = default;
+  ~ParallelCostModel() = default;
+  int GetNodeCalAmount(const AnfNodePtr &node);
+  std::tuple<std::vector<DimInfoPtr>, int> CalFuseInfo(const AnfNodePtrList &nodes);
+};
+
+using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>;
+
+class ParallelCostModelWarehouse {
+ public:
+  static ParallelCostModelWarehouse &Instance() {
+    static ParallelCostModelWarehouse instance;
+    return instance;
+  }
+  ParallelCostModelPtr GetParallelCostModel(const std::string &target);
+
+ private:
+  ParallelCostModelWarehouse() { cost_model_ = std::make_shared<ParallelCostModel>(); }
+  ParallelCostModelPtr cost_model_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
@@ -0,0 +1,876 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/parallel_fusion.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <list>
+#include <map>
+#include <memory>
+#include <queue>
+#include <set>
+#include <sstream>
+#include <stack>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "frontend/operator/ops.h"
+#include "ir/func_graph_cloner.h"
+#include "vm/segment_runner.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) {
+  return std::any_of(ops_prim.cbegin(), ops_prim.cend(),
+                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
+}
+
+void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn,
+                             OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
+  std::set<AnfNodePtr> latter_to_be_erased;
+  for (const auto &[node, node_rel] : (*node_rels)) {
+    if (!pass_fn(node) || latter_to_be_erased.count(node) != 0) {
+      continue;
+    }
+    auto nexts = node_rel.nexts;
+    std::vector<AnfNodePtr> pre_nodes;
+    std::queue<AnfNodePtr> node_que;
+    node_que.push(node);
+
+    // Walk backwards until pass_fn returns false on every path, collecting those predecessor nodes.
+    while (!node_que.empty()) {
+      auto cur_node = node_que.front();
+      node_que.pop();
+
+      if (!pass_fn(cur_node)) {
+        pre_nodes.push_back(cur_node);
+        continue;
+      }
+
+      latter_to_be_erased.insert(cur_node);
+      auto predecessors = (*node_rels)[cur_node].pres;
+      if (predecessors.empty()) {
+        continue;
+      }
+      for (const auto &pre_node : predecessors) {
+        (*node_rels)[cur_node].pres.erase(pre_node);
+        (*node_rels)[pre_node].nexts.erase(cur_node);
+        node_que.push(pre_node);
+      }
+    }
+
+    // Modify the relations: delete node <-> next_node, add pre_node <-> next_node.
+    for (const auto &next_node : nexts) {
+      (*node_rels)[next_node].pres.erase(node);
+      for (const auto &cur_node : pre_nodes) {
+        (*node_rels)[next_node].pres.insert(cur_node);
+        (*node_rels)[cur_node].nexts.insert(next_node);
+      }
+    }
+  }
+
+  for (const auto &node : latter_to_be_erased) {
+    node_rels->erase(node);
+  }
+}
+
+void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
+  for (auto &[node, node_rel] : (*node_rels)) {
+    if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
+      continue;
+    }
+
+    // Detach the attached nodes from this Depend node.
+    auto cnode = node->cast<CNodePtr>();
+    for (size_t id = kDependAttachNodeIndex; id < cnode->inputs().size(); ++id) {
+      auto attach_node = cnode->input(id);
+      if (auto iter = node_rels->find(attach_node); iter != node_rels->end()) {
+        iter->second.nexts.erase(node);
+      }
+      if (auto &cnode_pres = node_rel.pres; cnode_pres.count(attach_node) != 0) {
+        cnode_pres.erase(attach_node);
+      }
+    }
+  }
+
+  // Then eliminate the Depend nodes themselves from the node relations.
+  ProcessThroughPassCNode([](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimDepend}); }, node_rels);
+}
+
+std::tuple<std::pair<AnfNodePtr, AnfNodePtr>, std::pair<AnfNodePtrList, AnfNodePtrList>> FindRelationOfControlDepend(
+  const AnfNodePtr &node, OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
+  auto cnode = node->cast<CNodePtr>();
+  auto prior_node = cnode->input(kControlDependPriorIndex);
+  auto behind_node = cnode->input(kControlDependBehindIndex);
+  MS_EXCEPTION_IF_NULL(prior_node);
+  MS_EXCEPTION_IF_NULL(behind_node);
+
+  OrderedSet<AnfNodePtr> prior_nodes;
+  prior_nodes.insert(prior_node);
+  OrderedSet<AnfNodePtr> behind_nodes;
+  behind_nodes.insert(behind_node);
+
+  int64_t depend_mode = 0;
+  if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
+    depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
+  }
+  if (prior_node->isa<Parameter>() && depend_mode == 1) {
+    prior_nodes = (*node_rels)[prior_node].nexts;
+  }
+  if (behind_node->isa<Parameter>()) {
+    behind_nodes = depend_mode == 1 ? (*node_rels)[behind_node].nexts : OrderedSet<AnfNodePtr>();
+  }
+
+  // Get the real nodes.
+  AnfNodePtrList real_prior_nodes;
+  std::set<AnfNodePtr> prior_visited;
+  for (const auto &tmp : prior_nodes) {
+    AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
+  }
+  AnfNodePtrList real_behind_nodes;
+  std::set<AnfNodePtr> behind_visited;
+  for (const auto &tmp : behind_nodes) {
+    AnfAlgo::GetAllFatherRealNode(tmp, &real_behind_nodes, &behind_visited);
+  }
+
+  return std::make_tuple(std::make_pair(prior_node, behind_node), std::make_pair(real_prior_nodes, real_behind_nodes));
+}
+
+void ReLinkNodesOfControlDependByRelation(
+  const std::unordered_multimap<AnfNodePtr, AnfNodePtrList> &control_depend_info,
+  OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
+  // Relink and log each relation.
+  for (const auto &m : control_depend_info) {
+    const auto &prior = m.second[0];
+    const auto &behind = m.second[1];
+    (*node_rels)[prior].nexts.insert(behind);
+    (*node_rels)[behind].pres.insert(prior);
+    MS_LOG(DEBUG) << "Relink relation of " << m.first->fullname_with_scope() << ": " << prior->fullname_with_scope()
+                  << " -> " << behind->fullname_with_scope();
+  }
+}
+
+void ProcessControlDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
+  // A multimap, since one ControlDepend may expand to several (prior, behind)
+  // pairs; a plain map would silently drop all but the first insert per node.
+  std::unordered_multimap<AnfNodePtr, AnfNodePtrList> control_depend_info;
+  AnfNodePtrList latter_to_be_erased;
+
+  // Collect ControlDepend nodes and their prior/behind nodes.
+  for (auto &[node, node_rel] : (*node_rels)) {
+    if (!IsPrimitiveCNode(node, prim::kPrimControlDepend)) {
+      continue;
+    }
+
+    auto [direct_relation, real_relations] = FindRelationOfControlDepend(node, node_rels);
+    auto &[prior_node, behind_node] = direct_relation;
+    auto &[real_prior_nodes, real_behind_nodes] = real_relations;
+
+    (*node_rels)[prior_node].nexts.erase(node);
+    (*node_rels)[behind_node].nexts.erase(node);
+    node_rel.pres.erase(prior_node);
+    node_rel.pres.erase(behind_node);
+
+    for (auto &first_node : real_prior_nodes) {
+      for (auto &second_node : real_behind_nodes) {
+        MS_EXCEPTION_IF_NULL(first_node);
+        MS_EXCEPTION_IF_NULL(second_node);
+        control_depend_info.insert({node, {first_node, second_node}});
+      }
+    }
+    latter_to_be_erased.push_back(node);
+  }
+
+  // Delete the ControlDepend nodes before relinking their relations.
+  for (const auto &node : latter_to_be_erased) {
+    node_rels->erase(node);
+  }
+
+  // Rebuild the relations between prior and behind nodes.
+  ReLinkNodesOfControlDependByRelation(control_depend_info, node_rels);
+}
+
+void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
+  AnfNodePtrList latter_to_be_erased;
+  for (auto &[node, node_rel] : (*node_rels)) {
+    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
+      continue;
+    }
+
+    AnfNodePtrList check_next_list;
+    bool disinterested = false;
+    for (auto &successor : node_rel.nexts) {
+      if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
+        disinterested = true;
+        break;
+      }
+      check_next_list.push_back(successor);
+    }
+    if (disinterested) {
+      continue;
+    }
+
+    // A MakeTuple is a tail if each of its getitem consumers (if any) is dead.
+    // (The MakeTuple's own nexts are exactly those getitems, so only the
+    // getitems are checked here.)
+    if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
+                     [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
+      continue;
+    }
+
+    latter_to_be_erased.push_back(node);
+  }
+
+  // Delete tail MakeTuples, including their getitem nodes.
+  for (const auto &node : latter_to_be_erased) {
+    for (auto &pre : (*node_rels)[node].pres) {
+      (*node_rels)[pre].nexts.erase(node);
+    }
+
+    // A tail MakeTuple is consumed by nothing, or only by dead getitem nodes.
+    for (auto &getitem : (*node_rels)[node].nexts) {
+      node_rels->erase(getitem);
+    }
+
+    node_rels->erase(node);
+  }
+}
+
+bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
+  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
+    return true;
+  }
+  return false;
+}
+
+bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
+  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
+    return true;
+  }
+  return false;
+}
+
+bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
+  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
+    return true;
+  }
+  return false;
+}
+
+bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
+  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
+    return true;
+  }
+  return false;
+}
+
+bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
+  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.empty()) {
+    return true;
+  }
+  return false;
+}
+
+bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
+  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.empty()) {
+    return true;
+  }
+  return false;
+}
+
| void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels, std::set<AnfNodePtr> *virtual_noout_nodes, | |||||
| std::set<AnfNodePtr> *ignore_noin_nodes) { | |||||
| // 1. Local relation | |||||
| // Graph as following left part, relation D->B and D->E(D is a no input node) | |||||
| // will make B and E to be multiply inputs node. | |||||
| // But for parallel, this local relation can ignore for B and E, which make | |||||
| // them be able to be paralleled. | |||||
| // | |||||
| // ************************************ | |||||
| // * * | |||||
| // * | | * | |||||
| // * A D A D * | |||||
| // * | /| | / \ * | |||||
| // * | C | | C F * | |||||
| // * |/ / | | | * | |||||
| // * B F ====> B x x * | |||||
| // * | / | * | |||||
| // * |/ | * | |||||
| // * E E * | |||||
| // * | | * | |||||
| // * * | |||||
| // ************************************ | |||||
| AnfNodePtrList no_input_nodes; | |||||
| for (const auto &node_rel : *node_rels) { | |||||
| auto &node = node_rel.first; | |||||
| if (IsNoInputsNode(*node_rels, node)) { | |||||
| no_input_nodes.push_back(node); | |||||
| } | |||||
| } | |||||
| std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete; | |||||
| for (const auto &ninode : no_input_nodes) { | |||||
| AnfNodePtrList cnexts((*node_rels)[ninode].nexts.begin(), (*node_rels)[ninode].nexts.end()); | |||||
| for (const auto &n : cnexts) { | |||||
| AnfNodePtr serial_tail = ninode; | |||||
| AnfNodePtr cur_node = n; | |||||
| while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) { | |||||
| serial_tail = cur_node; | |||||
| cur_node = *((*node_rels)[cur_node].nexts.begin()); | |||||
| } | |||||
| latter_delete.emplace_back(serial_tail, cur_node); | |||||
| } | |||||
| } | |||||
| // Delete relation. | |||||
| for (const auto &[serial_tail, cur_node] : latter_delete) { | |||||
| virtual_noout_nodes->insert(serial_tail); | |||||
| ignore_noin_nodes->insert(cur_node); | |||||
| (*node_rels)[serial_tail].nexts.erase(cur_node); | |||||
| (*node_rels)[cur_node].pres.erase(serial_tail); | |||||
| MS_LOG(INFO) << "Local structure processing: delete relation " << serial_tail->fullname_with_scope() << " -> " |||||
| << cur_node->fullname_with_scope(); |||||
| } | |||||
| } | |||||
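| // Collect the four kinds of interesting nodes: multi-input nodes, multi-output nodes, |||||
| // no-input nodes (excluding the ignored ones) and no-output nodes (excluding the virtual ones). |||||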
| std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds( | |||||
| const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes, | |||||
| const std::set<AnfNodePtr> &ignore_noin_nodes) { | |||||
| AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes; | |||||
| std::list<std::function<void(const AnfNodePtr &)>> func_list = { | |||||
| [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) { | |||||
| if (IsMultiInputsNode(node_rels, node)) { | |||||
| multi_inputs_nodes.push_back(node); | |||||
| } | |||||
| }, | |||||
| [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) { | |||||
| if (IsMultiOutputsNode(node_rels, node)) { | |||||
| multi_outputs_nodes.push_back(node); | |||||
| } | |||||
| }, | |||||
| [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) { | |||||
| if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) { | |||||
| no_input_nodes.push_back(node); | |||||
| } | |||||
| }, | |||||
| [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) { | |||||
| if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) { | |||||
| no_output_nodes.push_back(node); | |||||
| } | |||||
| }}; | |||||
| for (const auto &node_rel : node_rels) { | |||||
| for (const auto &func : func_list) { | |||||
| func(node_rel.first); | |||||
| } | |||||
| } | |||||
| return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes); | |||||
| } | |||||
| bool WhiteOpsFilter(const AnfNodePtr &node) { | |||||
| std::vector<PrimitivePtr> whiteable_ops = {}; // No extra whitelisted ops for now. |||||
| return session::AnfRuntimeAlgorithm::IsGraphKernel(node) || IsOneOf(node, whiteable_ops); | |||||
| } | |||||
| std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes, | |||||
| const std::function<bool(const AnfNodePtr &)> &filter_func, | |||||
| const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward, | |||||
| std::set<AnfNodePtr> *seen) { | |||||
| // Start from a multi-input node; stop on a seen node or on another multi-input or multi-output node. |||||
| // For a backward search, another multi-input node may be included in the stream. |||||
| // For a forward search, another multi-output node may be included in the stream. |||||
| // Select the successor/predecessor sets by search direction; return by reference to avoid copying the sets. |||||
| auto get_contain_node_set = is_backward ? [](const NodeRelation &info) -> const OrderedSet<AnfNodePtr> & { return info.pres; } |||||
| : [](const NodeRelation &info) -> const OrderedSet<AnfNodePtr> & { return info.nexts; }; |||||
| auto get_exclude_node_set = is_backward ? [](const NodeRelation &info) -> const OrderedSet<AnfNodePtr> & { return info.nexts; } |||||
| : [](const NodeRelation &info) -> const OrderedSet<AnfNodePtr> & { return info.pres; }; |||||
| std::vector<AnfNodePtrList> group; | |||||
| for (const auto &node : nodes) { | |||||
| AnfNodePtrList stream; | |||||
| AnfNodePtr n = node; | |||||
| for (auto iter = node_rels.find(n); | |||||
| seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1; | |||||
| iter = node_rels.find(n)) { | |||||
| if (filter_func(n)) { | |||||
| stream.push_back(n); | |||||
| seen->insert(n); | |||||
| } | |||||
| if (get_contain_node_set(iter->second).size() != 1) { | |||||
| break; | |||||
| } | |||||
| n = *(get_contain_node_set(iter->second).begin()); | |||||
| } | |||||
| if (!stream.empty()) { |||||
| group.push_back(stream); | |||||
| } | |||||
| } | |||||
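| // A single stream has nothing to be parallelized with, so drop it and release its nodes. |||||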
| if (group.size() == 1) { | |||||
| for (const auto &drop : group[0]) { | |||||
| seen->erase(drop); | |||||
| } | |||||
| group.clear(); | |||||
| } | |||||
| return group; | |||||
| } | |||||
| void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes, | |||||
| const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward, | |||||
| std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) { | |||||
| auto get_related_nodes = is_backward ? [](const NodeRelation &info) -> const OrderedSet<AnfNodePtr> & { return info.pres; } |||||
| : [](const NodeRelation &info) -> const OrderedSet<AnfNodePtr> & { return info.nexts; }; |||||
| for (const auto &node : multi_nodes) { | |||||
| if (auto iter = node_rels.find(node); iter != node_rels.end()) { | |||||
| const auto &pre_nodes = get_related_nodes(iter->second); | |||||
| AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end()); | |||||
| groups->push_back(SearchFromNodes(related_nodes, WhiteOpsFilter, node_rels, is_backward, seen)); | |||||
| } | |||||
| } | |||||
| // Erase empty groups. | |||||
| for (auto iter = groups->begin(); iter != groups->end();) { | |||||
| if (iter->empty()) { |||||
| iter = groups->erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| } | |||||
| } | |||||
| void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes, | |||||
| const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward, | |||||
| std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) { | |||||
| groups->push_back(SearchFromNodes(ud_nodes, WhiteOpsFilter, node_rels, is_backward, seen)); | |||||
| // Erase empty groups. | |||||
| for (auto iter = groups->begin(); iter != groups->end();) { | |||||
| if (iter->empty()) { |||||
| iter = groups->erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| } | |||||
| } | |||||
| std::string DumpNode(const AnfNodePtr &node) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| std::stringstream buf; | |||||
| buf << (AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|" | |||||
| << cnode->ToString(); | |||||
| return buf.str(); | |||||
| } | |||||
| void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups) { | |||||
| MS_LOG(INFO) << "There are " << groups.size() << " parallel groups; their details are:"; |||||
| int i = 0; | |||||
| for (const auto &group : groups) { |||||
| std::stringstream buf; |||||
| buf << "[" << i << " group] " << group.size() << ":\n"; |||||
| for (const auto &nodes : group) { |||||
| buf << " " << nodes.size() << ": [<"; |||||
| for (const auto &node : nodes) { |||||
| buf << "(" << DumpNode(node) << ") -> "; | |||||
| } | |||||
| buf << ">]\n"; | |||||
| } | |||||
| i++; | |||||
| MS_LOG(INFO) << buf.str(); | |||||
| } | |||||
| } | |||||
| void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) { | |||||
| std::stringstream buf; | |||||
| buf << "Parallel fusion detail: "; | |||||
| for (const auto &node : source) { | |||||
| buf << "(" << DumpNode(node) << ") + "; | |||||
| } | |||||
| buf << "==>" | |||||
| << "(" << DumpNode(target) << ")"; | |||||
| MS_LOG(INFO) << buf.str(); | |||||
| } | |||||
| } // namespace | |||||
| OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) { | |||||
| // Based on the ANF node input information, build a simple graph for later analysis. |||||
| OrderedMap<AnfNodePtr, NodeRelation> node_rels; | |||||
| auto get_info = [&node_rels](const AnfNodePtr &node) { | |||||
| if (node_rels.count(node) == 0) { | |||||
| node_rels.insert({node, NodeRelation()}); | |||||
| } | |||||
| return &(node_rels[node]); | |||||
| }; | |||||
| for (const auto &node : nodes) { | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto node_info = get_info(node); |||||
| for (const auto &input : (node->cast<CNodePtr>())->inputs()) { |||||
| // Keep Parameter inputs: they can act as inputs of ControlDepend when depend mode is 1. |||||
| if (!input->isa<CNode>() && !input->isa<Parameter>()) { |||||
| continue; |||||
| } |||||
| auto input_info = get_info(input); |||||
| node_info->pres.insert(input); |||||
| input_info->nexts.insert(node); |||||
| } | |||||
| } | |||||
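| // Normalize the relation graph: fold Depend/ControlDepend nodes, see through shape ops, |||||
| // Parameters and the tail MakeTuple, and break local structures that block parallelization. |||||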
| ProcessDependCNode(&node_rels); | |||||
| ProcessControlDependCNode(&node_rels); | |||||
| ProcessThroughPassCNode( | |||||
| [](const AnfNodePtr &node) { | |||||
| return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem}); | |||||
| }, | |||||
| &node_rels); | |||||
| ProcessThroughPassCNode([](const AnfNodePtr &node) { return node->isa<Parameter>(); }, &node_rels); | |||||
| ProcessTailMakeTupleCNode(&node_rels); | |||||
| ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_); | |||||
| return node_rels; | |||||
| } | |||||
| std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups( | |||||
| const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) { | |||||
| // Get the interesting nodes: multi-input nodes, multi-output nodes, no-input nodes and no-output nodes. |||||
| auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] = | |||||
| GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_); | |||||
| // Get streams and group them | |||||
| std::set<AnfNodePtr> seen; | |||||
| std::vector<std::vector<AnfNodePtrList>> groups; | |||||
| SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen); | |||||
| SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen); | |||||
| SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen); | |||||
| SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen); | |||||
| DumpParallelGroups(groups); | |||||
| return groups; | |||||
| } | |||||
| std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvailableNodesByOffset( |||||
| int start, const std::vector<int> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes, | |||||
| const std::set<int> &excludes) { | |||||
| // Get unused nodes by offset index; the result will contain the node at the start index. |||||
| int node_limit = nodes.size(); | |||||
| if (start >= node_limit) { | |||||
| MS_LOG(EXCEPTION) << "Start index exceeds the number of given nodes."; |||||
| } | |||||
| AnfNodePtrList target_nodes = {nodes[start]}; | |||||
| std::vector<int> valid_indices; | |||||
| std::vector<int> unused; | |||||
| for (size_t i = start; i < used.size(); ++i) { | |||||
| if (!used[i] && excludes.count(i) == 0) { | |||||
| unused.push_back(i); | |||||
| } | |||||
| } | |||||
| int limit = unused.size(); | |||||
| for (auto offset : offsets) { | |||||
| if (offset >= limit) { | |||||
| MS_LOG(EXCEPTION) << "Index offset exceeds the number of unused nodes."; |||||
| } | |||||
| if (unused[offset] >= node_limit) { | |||||
| MS_LOG(EXCEPTION) << "Unused node index exceeds the number of nodes."; |||||
| } | |||||
| valid_indices.push_back(unused[offset]); | |||||
| target_nodes.push_back(nodes[unused[offset]]); | |||||
| } | |||||
| return std::make_tuple(target_nodes, valid_indices); | |||||
| } | |||||
| std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates( | |||||
| size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices, | |||||
| std::map<AnfNodePtr, int> *sorted_indices) { | |||||
| auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (indices->find(node) == indices->end()) { | |||||
| MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString(); | |||||
| } | |||||
| return (*indices)[node]; | |||||
| }; | |||||
| std::vector<ParallelInfo> parallel_infos; | |||||
| std::vector<bool> origin_candidates_used(origin_size, false); | |||||
| std::vector<bool> sorted_candidates_used(candidates.size(), false); | |||||
| for (size_t i = 0; i < candidates.size(); ++i) { | |||||
| if (sorted_candidates_used[i]) { | |||||
| continue; | |||||
| } | |||||
| int max_benefit = 0; | |||||
| ParallelInfo best_parallel_info; | |||||
| std::set<int> bad_set; | |||||
| size_t unused_num = 0; | |||||
| for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) { | |||||
| unused_num += sorted_candidates_used[j] ? 0 : 1; | |||||
| } | |||||
| if (unused_num < 1) { | |||||
| break; | |||||
| } | |||||
| unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1); | |||||
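| // Binary-search the largest number of extra candidates whose fusion still has a positive benefit. |||||
| // Note: this relies on the benefit being monotonic in the number of fused candidates. |||||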
| size_t begin = 1, end = unused_num; | |||||
| while (begin <= end) { | |||||
| size_t mid = (begin + end) / 2; | |||||
| std::vector<int> tc(mid); | |||||
| std::iota(tc.begin(), tc.end(), 1); | |||||
| AnfNodePtrList other_candidates; | |||||
| std::tie(other_candidates, std::ignore) = | |||||
| GetAvailableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>()); |||||
| int benefit; | |||||
| std::tie(std::ignore, benefit) = cost_model_ptr_->CalFuseInfo(other_candidates); | |||||
| if (benefit > 0) { | |||||
| begin = mid + 1; | |||||
| } else { | |||||
| end = mid - 1; | |||||
| } | |||||
| } | |||||
| if (begin > 1) { | |||||
| std::vector<int> tc(begin - 1); | |||||
| std::iota(tc.begin(), tc.end(), 1); | |||||
| AnfNodePtrList other_candidates; | |||||
| std::tie(other_candidates, std::ignore) = | |||||
| GetAvailableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>()); |||||
| auto [dim_infos, benefit] = cost_model_ptr_->CalFuseInfo(other_candidates); | |||||
| if (benefit <= 0) { | |||||
| MS_LOG(EXCEPTION) << "Internal error in candidate search!"; | |||||
| } | |||||
| max_benefit = benefit; | |||||
| best_parallel_info = ParallelInfo(other_candidates, dim_infos); | |||||
| i += begin - 1; | |||||
| } | |||||
| if (max_benefit > 0) { | |||||
| parallel_infos.push_back(best_parallel_info); | |||||
| for (const auto &node : best_parallel_info.nodes()) { | |||||
| sorted_candidates_used[get_index(sorted_indices, node)] = true; | |||||
| origin_candidates_used[get_index(origin_indices, node)] = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| // The current nodes are not suitable for fusion, so pop the first node to try other fusion possibilities. |||||
| if (parallel_infos.empty()) { |||||
| origin_candidates_used[get_index(origin_indices, candidates[0])] = true; | |||||
| } | |||||
| return std::make_tuple(origin_candidates_used, parallel_infos); | |||||
| } | |||||
| std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates( | |||||
| const AnfNodePtrList &cs) { | |||||
| std::map<AnfNodePtr, int> origin_indices; | |||||
| std::vector<size_t> indices; | |||||
| for (size_t i = 0; i < cs.size(); ++i) { | |||||
| if (cs[i]) { | |||||
| origin_indices.insert({cs[i], i}); | |||||
| indices.push_back(i); | |||||
| } | |||||
| } | |||||
| // A computation-heavy node can cover the cost of more lighter nodes, so sort the candidates by calculation amount first. |||||
| std::map<size_t, int> cal_amounts; | |||||
| for (auto id : indices) { | |||||
| cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]); | |||||
| } | |||||
| std::sort(indices.begin(), indices.end(), | |||||
| [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; }); | |||||
| AnfNodePtrList candidates; | |||||
| for (size_t i = 0; i < indices.size(); ++i) { | |||||
| candidates.push_back(cs[indices[i]]); | |||||
| } | |||||
| std::map<AnfNodePtr, int> sorted_indices; | |||||
| for (size_t i = 0; i < candidates.size(); ++i) { | |||||
| sorted_indices.insert({candidates[i], i}); | |||||
| } | |||||
| return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices); | |||||
| } | |||||
| void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group, | |||||
| std::vector<ParallelInfo> *parallel_infos) { | |||||
| std::vector<AnfNodePtrList::const_iterator> tails; | |||||
| std::vector<AnfNodePtrList::const_iterator> ended; | |||||
| for (const auto &node_list : group) { | |||||
| tails.push_back(node_list.begin()); | |||||
| ended.push_back(node_list.end()); | |||||
| } | |||||
| auto get_candidates = [&tails, &ended]() { | |||||
| AnfNodePtrList candidates; | |||||
| for (size_t id = 0; id < tails.size(); ++id) { | |||||
| candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr()); | |||||
| } | |||||
| return candidates; | |||||
| }; | |||||
| auto update_tails = [&tails](const std::vector<bool> &used) { | |||||
| if (used.size() != tails.size()) { | |||||
| MS_LOG(EXCEPTION) << "The number of used flags is not equal to the number of candidate streams!"; |||||
| } | |||||
| for (size_t id = 0; id < used.size(); ++id) { | |||||
| if (used[id]) { | |||||
| tails[id]++; | |||||
| } | |||||
| } | |||||
| }; | |||||
| auto valid_candidate_num = [](const AnfNodePtrList &cs) { | |||||
| return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; }); | |||||
| }; | |||||
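| // Repeatedly take the current head of every stream as the candidate set, search fusable nodes in it, |||||
| // then advance the consumed streams until fewer than two candidates remain. |||||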
| auto candidates = get_candidates(); | |||||
| while (valid_candidate_num(candidates) > 1) { | |||||
| auto [used, fnds] = SearchFuseNodesInCandidates(candidates); | |||||
| parallel_infos->insert(parallel_infos->end(), fnds.cbegin(), fnds.cend()); |||||
| update_tails(used); | |||||
| candidates = get_candidates(); | |||||
| } | |||||
| } | |||||
| std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes( | |||||
| const std::vector<std::vector<AnfNodePtrList>> &groups) { | |||||
| // Find fusable node groups with the cost model. |||||
| std::vector<ParallelInfo> parallel_infos; | |||||
| for (const auto &group : groups) { | |||||
| SearchFuseNodesInParallelGroup(group, ¶llel_infos); | |||||
| } | |||||
| return parallel_infos; | |||||
| } | |||||
| void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info) { | |||||
| for (size_t i = 0; i < parallel_info.GetSize(); ++i) { | |||||
| const auto &fuse_nodes = parallel_info.nodes(); | |||||
| std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()}; | |||||
| if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) { | |||||
| MakeCNodeSafeForAttr(fuse_nodes[i]); | |||||
| AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]); | |||||
| } else { | |||||
| auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0)); | |||||
| auto out_node = node_g->output(); | |||||
| if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { | |||||
| auto inputs = out_node->cast<CNodePtr>()->inputs(); | |||||
| for (size_t j = 1; j < inputs.size(); ++j) { | |||||
| MakeCNodeSafeForAttr(inputs[j]); | |||||
| AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]); | |||||
| } | |||||
| } else { | |||||
| MakeCNodeSafeForAttr(out_node); | |||||
| AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
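| // Remove Depend users that only attach to the new fused node; if all users are such Depends, |||||
| // keep one of them to preserve the dependency. |||||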
| void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| auto mng = kernel_graph->manager(); | |||||
| if (mng == nullptr) { | |||||
| mng = Manage(kernel_graph, true); | |||||
| kernel_graph->set_manager(mng); | |||||
| } | |||||
| const auto &users = mng->node_users()[node]; | |||||
| std::vector<std::pair<AnfNodePtr, int>> sons; | |||||
| for (const auto &[user, index] : users) { | |||||
| if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) { | |||||
| sons.emplace_back(user, index); | |||||
| continue; | |||||
| } | |||||
| auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin()); | |||||
| sons.emplace_back(fake_first_grad_son, grad_index); | |||||
| } | |||||
| AnfNodePtrList latter_to_delete; | |||||
| for (const auto &[son, index] : sons) { | |||||
| if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) { | |||||
| continue; | |||||
| } | |||||
| latter_to_delete.push_back(son); | |||||
| } | |||||
| if (latter_to_delete.empty()) { | |||||
| return; | |||||
| } | |||||
| std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin(); | |||||
| if (latter_to_delete.size() == sons.size()) { | |||||
| // Keep one Depend node relation and delete the others. |||||
| ++delete_begin; | |||||
| } | |||||
| for (; delete_begin != latter_to_delete.end(); ++delete_begin) { | |||||
| auto depend_anfnode = *delete_begin; | |||||
| auto depend_cnode = depend_anfnode->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(depend_cnode); | |||||
| auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend); | |||||
| mng->Replace(depend_anfnode, depend_prior_node); | |||||
| } | |||||
| } | |||||
| bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos, | |||||
| const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| bool changed = false; | |||||
| for (size_t i = 0; i < parallel_infos.size(); ++i) { | |||||
| const auto &fuse_nodes = parallel_infos[i].nodes(); | |||||
| if (fuse_nodes.size() <= 1) { | |||||
| continue; | |||||
| } | |||||
| changed = true; | |||||
| SetFusedParallelOpAttrToReturnNode(parallel_infos[i]); | |||||
| AnfNodePtr sg_node; | |||||
| std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel"); | |||||
| PostProcessForNewSubGraphCNode(sg_node, kernel_graph); | |||||
| DumpParallelFusionDetail(fuse_nodes, sg_node); | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| bool ParallelOpFusion::Run(const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_); | |||||
| MS_EXCEPTION_IF_NULL(cost_model_ptr_); | |||||
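| // Analyze the nodes in reversed topological order, i.e. from outputs to inputs. |||||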
| auto nodes = TopoSort(kernel_graph->get_return()); | |||||
| std::reverse(nodes.begin(), nodes.end()); | |||||
| auto node_rels = GenAnalysisGraph(nodes); | |||||
| auto groups = SearchParallelGroups(node_rels); | |||||
| auto parallel_infos = SearchFusableParallelCNodes(groups); | |||||
| // Create the fused subgraphs and modify the origin graph. |||||
| return CreateParallelOpSubGraphs(parallel_infos, kernel_graph); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,122 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include <vector> | |||||
| #include "base/base.h" | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "backend/optimizer/graph_kernel/parallel_cost_model.h" | |||||
| #include "backend/session/kernel_graph.h" | |||||
| #include "utils/ms_context.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
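| // A group of nodes to be fused in parallel, together with their parallel dim infos. |||||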
| class ParallelInfo { | |||||
| public: | |||||
| ParallelInfo() = default; | |||||
| ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims) : nodes_(nodes), dims_(dims) {} | |||||
| ParallelInfo(const ParallelInfo &obj) = default; |||||
| ~ParallelInfo() = default; | |||||
| size_t GetSize() const { | |||||
| if (nodes_.size() != dims_.size()) { | |||||
| MS_LOG(EXCEPTION) << "Internal error in parallel info!"; | |||||
| } | |||||
| return nodes_.size(); | |||||
| } | |||||
| const AnfNodePtrList &nodes() const { return nodes_; } | |||||
| const std::vector<DimInfoPtr> &dims() const { return dims_; } | |||||
| private: | |||||
| AnfNodePtrList nodes_; | |||||
| std::vector<DimInfoPtr> dims_; | |||||
| }; | |||||
| class ParallelConfig { | |||||
| public: | |||||
| ParallelConfig() = default; | |||||
| explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {} | |||||
| ParallelConfig(const ParallelConfig &obj) = default; |||||
| ~ParallelConfig() = default; |||||
| size_t max_num_for_fuse() const { return max_num_for_fuse_; } |||||
| private: |||||
| size_t max_num_for_fuse_{10}; // Fusing too many nodes together may produce a bad result. |||||
| }; | |||||
| struct NodeRelation { | |||||
| NodeRelation() = default; |||||
| ~NodeRelation() = default; | |||||
| OrderedSet<AnfNodePtr> pres; | |||||
| OrderedSet<AnfNodePtr> nexts; | |||||
| }; | |||||
| class ParallelOpFusion : public Pass { | |||||
| public: | |||||
| ParallelOpFusion(const std::string &target, const ParallelConfig &config) | |||||
| : Pass("parallel_fusion"), target_(target), config_(config) {} | |||||
| ~ParallelOpFusion() override = default; | |||||
| bool Run(const FuncGraphPtr &graph) override; | |||||
| private: | |||||
| std::tuple<AnfNodePtrList, std::vector<int>> GetAvailableNodesByOffset(int start, const std::vector<int> &offsets, |||||
| const std::vector<bool> &used, | |||||
| const AnfNodePtrList &nodes, | |||||
| const std::set<int> &excludes); | |||||
| std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates( | |||||
| size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices, | |||||
| std::map<AnfNodePtr, int> *sorted_indices); | |||||
| std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs); | |||||
| void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group, | |||||
| std::vector<ParallelInfo> *parallel_infos); | |||||
| std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups); | |||||
| void SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info); | |||||
| bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos, | |||||
| const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||||
| OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes); | |||||
| std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels); | |||||
| std::string target_; | |||||
| ParallelConfig config_; | |||||
| ParallelCostModelPtr cost_model_ptr_; | |||||
| std::set<AnfNodePtr> virtual_noout_nodes_; | |||||
| std::set<AnfNodePtr> ignore_noin_nodes_; | |||||
| }; | |||||
| using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ | |||||
| @@ -43,6 +43,7 @@ | |||||
| #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | ||||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | ||||
| #include "backend/optimizer/graph_kernel/clean_all_in_once.h" | #include "backend/optimizer/graph_kernel/clean_all_in_once.h" | ||||
| #include "backend/optimizer/graph_kernel/depend_formater.h" | |||||
| #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" | #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" | ||||
| #include "backend/optimizer/graph_kernel/tensor_promotion.h" | #include "backend/optimizer/graph_kernel/tensor_promotion.h" | ||||
| #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" | #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" | ||||
| @@ -51,6 +52,7 @@ | |||||
| #include "backend/optimizer/graph_kernel/graph_kernel_cse.h" | #include "backend/optimizer/graph_kernel/graph_kernel_cse.h" | ||||
| #include "backend/optimizer/graph_kernel/shape_ops_splitter.h" | #include "backend/optimizer/graph_kernel/shape_ops_splitter.h" | ||||
| #include "backend/optimizer/graph_kernel/value_graph_binder.h" | #include "backend/optimizer/graph_kernel/value_graph_binder.h" | ||||
| #include "backend/optimizer/graph_kernel/parallel_fusion.h" | |||||
| #include "backend/optimizer/pass/communication_op_fusion.h" | #include "backend/optimizer/pass/communication_op_fusion.h" | ||||
| #include "backend/optimizer/pass/getitem_tuple.h" | #include "backend/optimizer/pass/getitem_tuple.h" | ||||
| #include "common/trans.h" | #include "common/trans.h" | ||||
| @@ -179,6 +181,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_ | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm"); | auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm"); | ||||
| std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; | std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; | ||||
| pm->AddPass(std::make_shared<opt::DependFormater>()); // Make more fusion opportunities. |||||
| pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); | pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); | ||||
| pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops)); | pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops)); | ||||
| pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); | pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); | ||||
| @@ -196,7 +199,8 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_ | |||||
| // will be exposed, use GetitemTuple Pass to delete them. | // will be exposed, use GetitemTuple Pass to delete them. | ||||
| pm->AddPass(std::make_shared<opt::GetitemTuple>()); | pm->AddPass(std::make_shared<opt::GetitemTuple>()); | ||||
| pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>()); | pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>()); | ||||
| pm->AddPass(std::make_shared<opt::CleanAllInOnce>()); | |||||
| pm->AddPass(std::make_shared<opt::DependFormater>()); // Prevent fake loops in parallel fusion. |||||
| pm->AddPass(std::make_shared<opt::ParallelOpFusion>(kGPUDevice, opt::ParallelConfig(7))); | |||||
| pm->AddPass(std::make_shared<opt::BindValueToGraph>()); | pm->AddPass(std::make_shared<opt::BindValueToGraph>()); | ||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -382,6 +382,7 @@ constexpr auto kAttrPadding = "padding"; | |||||
| constexpr auto kAttrIsGrad = "is_grad"; | constexpr auto kAttrIsGrad = "is_grad"; | ||||
| constexpr auto kAttrRecompute = "recompute"; | constexpr auto kAttrRecompute = "recompute"; | ||||
| constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; | constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; | ||||
| constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; | |||||
| // attr value | // attr value | ||||
| constexpr auto kValueTargetSwitch = "target_switch"; | constexpr auto kValueTargetSwitch = "target_switch"; | ||||
| @@ -0,0 +1,54 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ========================================================================== | |||||
| """test graph parallel case""" | |||||
| import model | |||||
| def injective_graph(shape): | |||||
| gb = model.GraphBuilder() | |||||
| with gb.graph_scope('injective') as _: | |||||
| a1 = gb.tensor(shape, 'float32') | |||||
| a2 = gb.emit('Abs', a1) | |||||
| a3 = gb.emit('Abs', a2) | |||||
| gb.emit('Abs', a3) | |||||
| return gb.get()[0] | |||||
| def reduce_graph(shape, reduce_axis): | |||||
| gb = model.GraphBuilder() | |||||
| with gb.graph_scope('reduce') as _: | |||||
| a1 = gb.tensor(shape, 'float32') | |||||
| a2 = gb.emit('Abs', a1) | |||||
| a3 = gb.emit('Abs', a2) | |||||
| gb.emit('ReduceSum', a3, 'C', attrs={'reduce_axis': reduce_axis}) | |||||
| return gb.get()[0] | |||||
| def control_graph(shape): | |||||
| gb = model.GraphBuilder() | |||||
| with gb.graph_scope('control') as _: | |||||
| a1 = gb.tensor(shape, 'float32') | |||||
| a2 = gb.emit('Abs', a1) | |||||
| gb.emit('ControlDepend', a2) | |||||
| return gb.get()[0] | |||||
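| # Returns True when the estimator chooses block fusion with a positive gain. |||||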
| def block_fusion(graphs): | |||||
| gain = model.parallel_estimate(graphs) | |||||
| print("fusion = {}, bottleneck = {}, gain = {}".format(gain.fusion_type, gain.bottleneck, gain.gain)) | |||||
| return gain.fusion_type == "block_fusion" and gain.gain > 0 | |||||
| if __name__ == "__main__": | |||||
| assert block_fusion([injective_graph([40, 1024]), injective_graph([40, 1024])]) | |||||
| assert block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([24, 1024])]) | |||||
| assert not block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([50, 1024])]) | |||||
| assert not block_fusion([reduce_graph([1024, 1024], [0, 1]), injective_graph([1024, 1024])]) | |||||
| assert block_fusion([control_graph([20, 128]), injective_graph([40, 1024])]) | |||||