From: @tronzhang Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1 +1 @@ | |||||
| Subproject commit 94cb709ecaf5d1d869883dfe80cee7497dd0692c | |||||
| Subproject commit 24ba04df564fb3d2578e1b4324c760783b34d551 | |||||
| @@ -17,11 +17,12 @@ from .model import PrimLib | |||||
| class ParalGain: | class ParalGain: | ||||
| def __init__(self, fusion_type, bottleneck, gain, block_assign): | |||||
| def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info): | |||||
| self.fusion_type = fusion_type | self.fusion_type = fusion_type | ||||
| self.bottleneck = bottleneck | self.bottleneck = bottleneck | ||||
| self.gain = gain | self.gain = gain | ||||
| self.block_assign = block_assign | self.block_assign = block_assign | ||||
| self.type_info = type_info | |||||
| class ScheduleAnalyzer: | class ScheduleAnalyzer: | ||||
| @@ -30,6 +31,7 @@ class ScheduleAnalyzer: | |||||
| MAX_SM = 80 # Volta | MAX_SM = 80 # Volta | ||||
| MAX_NUM_THREADS = 1024 | MAX_NUM_THREADS = 1024 | ||||
| MAX_BLOCK = 256 | MAX_BLOCK = 256 | ||||
| PIPELINE_OP_THREADHOLD = 5 | |||||
| def __init__(self, graph): | def __init__(self, graph): | ||||
| self.graph = graph | self.graph = graph | ||||
| @@ -132,11 +134,141 @@ class ScheduleAnalyzer: | |||||
| else: | else: | ||||
| self.default_analyze() | self.default_analyze() | ||||
| def suitable_to_pipeline(self): | |||||
| """judge whether is suitable to be pipeline optimized""" | |||||
| # Reduce is not suitable | |||||
| def _contain_reduce(ops): | |||||
| for op in ops: | |||||
| # Reduce may make the tiling bad. | |||||
| if PrimLib.primtives.get(op.prim, None) == PrimLib.REDUCE: | |||||
| return True | |||||
| return False | |||||
| suitable = True | |||||
| if _contain_reduce(self.ops): | |||||
| suitable = False | |||||
| return suitable | |||||
| @staticmethod | |||||
| def k_mean(data, class_n=2, exclude_id=()): | |||||
| """ | |||||
| Find k clusters in which element is close to each other. | |||||
| Args: | |||||
| data (list): Elements' information. | |||||
| class_n (int): Number of clusters wanted to be analyzed, default is 2. | |||||
| exclude_id (tuple[int]): The list of excluded element's index, default is (). | |||||
| Returns: | |||||
| classes (list[list[int]]): The list of clusters. Each cluster is a list of indices. | |||||
| """ | |||||
| def _cal_mean(classes): | |||||
| class_datas = [[data[cid] for cid in cls] for cls in classes] | |||||
| return [sum(cls) / len(cls) if cls else float('inf') for cls in class_datas] | |||||
| def _cal_distance(a, b): | |||||
| return abs(a - b) | |||||
| def _check_different(old_classes, new_classes): | |||||
| for o, n in zip(old_classes, new_classes): | |||||
| if o != n: | |||||
| return True | |||||
| return False | |||||
| if len(data) < class_n: | |||||
| return None | |||||
| classes = [] | |||||
| for i, _ in enumerate(data): | |||||
| if i in exclude_id: | |||||
| continue | |||||
| if len(classes) >= class_n: | |||||
| break | |||||
| classes.append([i]) | |||||
| changed = True | |||||
| while changed: | |||||
| new_classes = [[] for cls in classes] | |||||
| means = _cal_mean(classes) | |||||
| for idx, d in enumerate(data): | |||||
| if idx in exclude_id: | |||||
| continue | |||||
| min_idx = -1 | |||||
| min_dis = float('inf') | |||||
| for i, m in enumerate(means): | |||||
| cur_dis = _cal_distance(m, d) | |||||
| min_idx = i if min_dis > cur_dis else min_idx | |||||
| min_dis = cur_dis if min_dis > cur_dis else min_dis | |||||
| new_classes[min_idx].append(idx) | |||||
| changed = _check_different(classes, new_classes) | |||||
| classes = new_classes | |||||
| return classes | |||||
| @staticmethod | |||||
| def pipeline_fusion_analyze(blocks, op_sizes, exclude_id): | |||||
| """analyze whether the segments can be pipeline optimized""" | |||||
| # op size first, block second. | |||||
| def _simple_factor(block, op_size): | |||||
| return block + 5 * op_size | |||||
| def _take_second(elem): | |||||
| return elem[1] | |||||
| simple_indicators = [_simple_factor(b, s) | |||||
| for b, s in zip(blocks, op_sizes)] | |||||
| # 2 classes, one heavy, the other light | |||||
| classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id) | |||||
| if not classes: | |||||
| return [] | |||||
| means = [sum([simple_indicators[idx] for idx in cls]) / | |||||
| len(cls) if cls else float('inf') for cls in classes] | |||||
| # The target two clusters should be a heavy one and a light one. | |||||
| # The light one maybe suitable to run with pipeline optimized. | |||||
| classes_infos = [[cls, m] for cls, m in zip(classes, means)] | |||||
| classes_infos.sort(key=_take_second) | |||||
| pipeline_target = None | |||||
| for ci in classes_infos: | |||||
| if ci: | |||||
| pipeline_target = ci | |||||
| break | |||||
| pipeline_gids, pipeline_mean = pipeline_target | |||||
| if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks), | |||||
| ScheduleAnalyzer.PIPELINE_OP_THREADHOLD): | |||||
| return [] | |||||
| pipeline_blocks = [] | |||||
| pipeline_weight = len(pipeline_gids) | |||||
| # Try to make two paralleled at least. | |||||
| if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2: | |||||
| if len(pipeline_gids[:pipeline_weight // 2]) > 1: | |||||
| pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2]) | |||||
| if len(pipeline_gids[pipeline_weight // 2:]) > 1: | |||||
| pipeline_blocks.append(pipeline_gids[pipeline_weight // 2:]) | |||||
| elif pipeline_weight > 1: | |||||
| pipeline_blocks.append(pipeline_gids) | |||||
| return pipeline_blocks | |||||
| @staticmethod | |||||
| def fusion_consult(blocks, op_sizes, exclude_gid): | |||||
| """get a recommendation for parallel fusion""" | |||||
| # Default is block fusion | |||||
| fusion_type = "block_fusion" | |||||
| type_info = None | |||||
| activate_pipeline_optimization = False # Disable pipeline optimization for now. | |||||
| if activate_pipeline_optimization: | |||||
| pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze( | |||||
| blocks, op_sizes, exclude_gid) | |||||
| if pipeline_info: | |||||
| fusion_type = "block_pipeline_fusion" | |||||
| type_info = pipeline_info | |||||
| return fusion_type, type_info | |||||
| def block_parallel_estimate(graphs): | def block_parallel_estimate(graphs): | ||||
| """estimate block parallel gain""" | """estimate block parallel gain""" | ||||
| sum_block, max_weight, sum_weight, blocks = 0, 0, 0, [] | |||||
| for g in graphs: | |||||
| sum_block, max_weight, sum_weight, blocks, op_sizes, exclude_gid = 0, 0, 0, [], [], [] | |||||
| for gid, g in enumerate(graphs): | |||||
| s = ScheduleAnalyzer(g) | s = ScheduleAnalyzer(g) | ||||
| s.analyze() | s.analyze() | ||||
| sum_block += s.block_num | sum_block += s.block_num | ||||
| @@ -144,9 +276,14 @@ def block_parallel_estimate(graphs): | |||||
| max_weight = s.block_weight | max_weight = s.block_weight | ||||
| sum_weight += s.block_weight | sum_weight += s.block_weight | ||||
| blocks.append(s.block_num) | blocks.append(s.block_num) | ||||
| op_sizes.append(len(s.ops)) | |||||
| if not s.suitable_to_pipeline(): | |||||
| exclude_gid.append(gid) | |||||
| if sum_block > ScheduleAnalyzer.MAX_SM * 32: | if sum_block > ScheduleAnalyzer.MAX_SM * 32: | ||||
| return ParalGain("none", sum_weight, 0, []) | |||||
| return ParalGain("block_fusion", max_weight, sum_weight - max_weight, blocks) | |||||
| return ParalGain("none", sum_weight, 0, [0 for _ in graphs], None) | |||||
| fusion_type, type_info = ScheduleAnalyzer.fusion_consult(blocks, op_sizes, tuple(exclude_gid)) | |||||
| return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info) | |||||
| def parallel_estimate(graphs): | def parallel_estimate(graphs): | ||||
| @@ -28,10 +28,8 @@ def estimate_ops(json_str: str): | |||||
| for gd in graph_descs: | for gd in graph_descs: | ||||
| graphs.append(model.load_composite(gd).graph) | graphs.append(model.load_composite(gd).graph) | ||||
| estimation = model.parallel_estimate(graphs) | estimation = model.parallel_estimate(graphs) | ||||
| if estimation.fusion_type == "block_fusion" and estimation.gain > 0: | |||||
| res = (estimation.block_assign, estimation.gain) | |||||
| else: | |||||
| res = ([0 for g in graphs], 0) | |||||
| res = (estimation.block_assign, estimation.gain, | |||||
| estimation.fusion_type, estimation.type_info) | |||||
| return res | return res | ||||
| except jd.JSONDecodeError: | except jd.JSONDecodeError: | ||||
| logger.error(traceback.format_exc()) | logger.error(traceback.format_exc()) | ||||
| @@ -557,30 +557,6 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j | |||||
| return true; | return true; | ||||
| } | } | ||||
| void AkgKernelJsonGenerator::SetParallelValueToJson(const std::string &processor, | |||||
| const std::map<size_t, size_t> &dim_infos, | |||||
| nlohmann::json *sub_fusion_json) { | |||||
| if (processor == kProcessorCuda) { | |||||
| std::vector<size_t> cnums; | |||||
| std::transform(dim_infos.cbegin(), dim_infos.cend(), std::back_insert_iterator(cnums), | |||||
| [](const std::pair<size_t, size_t> &dim) { return dim.second; }); | |||||
| (*sub_fusion_json)[kJsonKeyCoreNum] = cnums; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Parallel fusion not support " << processor << " now."; | |||||
| } | |||||
| } | |||||
| void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json) { | |||||
| nlohmann::json parallel_fusion_json; | |||||
| parallel_fusion_json[kJsonKeyFusionType] = "block_fusion"; | |||||
| std::vector<std::vector<std::string>> sgraphs; | |||||
| std::transform(sub_graphs_.cbegin(), sub_graphs_.cend(), std::back_insert_iterator(sgraphs), | |||||
| [](const std::pair<int, std::vector<std::string>> &sg) { return sg.second; }); | |||||
| parallel_fusion_json[kJsonKeySubGraph] = sgraphs; | |||||
| SetParallelValueToJson(processor, dim_infos_, ¶llel_fusion_json); | |||||
| (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json; | |||||
| } | |||||
| void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, | void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, | ||||
| std::map<AnfNodePtr, nlohmann::json> *node_json_map, | std::map<AnfNodePtr, nlohmann::json> *node_json_map, | ||||
| nlohmann::json *kernel_json) { | nlohmann::json *kernel_json) { | ||||
| @@ -633,12 +609,8 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf | |||||
| (*kernel_json)[kJsonKeyOutputDesc] = | (*kernel_json)[kJsonKeyOutputDesc] = | ||||
| CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map); | CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map); | ||||
| auto processor = GetProcessorStr(anf_nodes[0]); | |||||
| // Add parallel fusion information. | // Add parallel fusion information. | ||||
| if (!sub_graphs_.empty()) { | |||||
| AddParalleFusionJsonInfo(processor, kernel_json); | |||||
| } | |||||
| GenParallelJson(anf_nodes, input_list, output_list, node_json_map, kernel_json); | |||||
| size_t hash_id = std::hash<std::string>()(kernel_json->dump()); | size_t hash_id = std::hash<std::string>()(kernel_json->dump()); | ||||
| kernel_name_ = "Fused_"; | kernel_name_ = "Fused_"; | ||||
| @@ -660,7 +632,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf | |||||
| (*kernel_json)[kJsonKeyId] = GetOpCntInc(); | (*kernel_json)[kJsonKeyId] = GetOpCntInc(); | ||||
| (*kernel_json)[kJsonKeyOp] = kernel_name_; | (*kernel_json)[kJsonKeyOp] = kernel_name_; | ||||
| (*kernel_json)[kJsonKeyPlatform] = "AKG"; | (*kernel_json)[kJsonKeyPlatform] = "AKG"; | ||||
| (*kernel_json)[kJsonKeyProcess] = processor; | |||||
| (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]); | |||||
| (*kernel_json)[kJsonKeyComposite] = true; | (*kernel_json)[kJsonKeyComposite] = true; | ||||
| (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id(); | (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id(); | ||||
| @@ -755,6 +727,70 @@ nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNod | |||||
| return inputs_json; | return inputs_json; | ||||
| } | } | ||||
| void AkgKernelJsonGenerator::GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes, | |||||
| const std::vector<AnfNodePtr> &input_list, | |||||
| const std::vector<AnfNodePtr> &output_list, | |||||
| const std::map<AnfNodePtr, nlohmann::json> &node_json_map, | |||||
| nlohmann::json *kernel_json) { | |||||
| std::map<size_t, std::pair<size_t, std::vector<std::string>>> sub_graphs_info; | |||||
| std::string fusion_type; | |||||
| std::vector<std::vector<int>> type_info; | |||||
| auto output_index = GetOutputIndex(anf_nodes, input_list, output_list); | |||||
| for (size_t i = 0; i < output_index.size(); ++i) { | |||||
| auto [tmp_output, tmp_output_index] = output_index[i]; | |||||
| bool found = std::any_of(input_list.cbegin(), input_list.cend(), | |||||
| [&tmp_output](const AnfNodePtr &in) { return tmp_output == in; }); | |||||
| if (!found) { | |||||
| auto tcnode = tmp_output->cast<CNodePtr>(); | |||||
| if (tcnode == nullptr) { | |||||
| return; | |||||
| } | |||||
| // Get dim info. | |||||
| if (AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) { | |||||
| auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo); | |||||
| if (info.size() != 2) { | |||||
| MS_LOG(EXCEPTION) << "Parallel dim info is invalid!"; | |||||
| } | |||||
| auto tensor_name = | |||||
| GetTensorName(node_json_map.at(tmp_output), kJsonKeyOutputDesc, std::make_pair(0, tmp_output_index)); | |||||
| sub_graphs_info[info[0]].second.push_back(tensor_name); | |||||
| sub_graphs_info[info[0]].first = info[1]; | |||||
| } | |||||
| // Get fusion type. | |||||
| if (AnfAlgo::HasNodeAttr(kAttrParallelFusionType, tcnode)) { | |||||
| fusion_type = AnfAlgo::GetNodeAttr<std::string>(tcnode, kAttrParallelFusionType); | |||||
| } | |||||
| // Get fusion type info. | |||||
| if (AnfAlgo::HasNodeAttr(kAttrParallelTypeInfo, tcnode)) { | |||||
| type_info = AnfAlgo::GetNodeAttr<std::vector<std::vector<int>>>(tcnode, kAttrParallelTypeInfo); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!sub_graphs_info.empty()) { | |||||
| auto processor = GetProcessorStr(anf_nodes[0]); | |||||
| if (processor != kProcessorCuda) { | |||||
| MS_LOG(EXCEPTION) << "Parallel fusion not support " << processor << " now."; | |||||
| } | |||||
| nlohmann::json parallel_fusion_json; | |||||
| parallel_fusion_json[kJsonKeyFusionType] = fusion_type; | |||||
| parallel_fusion_json[kJsonKeyTypeInfo] = type_info; | |||||
| std::vector<std::vector<std::string>> sgraphs; | |||||
| std::vector<size_t> cnums; | |||||
| std::for_each(sub_graphs_info.cbegin(), sub_graphs_info.cend(), | |||||
| [&sgraphs, &cnums](const std::pair<size_t, std::pair<size_t, std::vector<std::string>>> &sg_info) { | |||||
| sgraphs.push_back(sg_info.second.second); | |||||
| cnums.push_back(sg_info.second.first); | |||||
| }); | |||||
| parallel_fusion_json[kJsonKeySubGraph] = sgraphs; | |||||
| parallel_fusion_json[kJsonKeyCoreNum] = cnums; | |||||
| (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json; | |||||
| } | |||||
| } | |||||
| nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes, | nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes, | ||||
| const std::vector<AnfNodePtr> &input_list, | const std::vector<AnfNodePtr> &input_list, | ||||
| const std::vector<AnfNodePtr> &output_list, | const std::vector<AnfNodePtr> &output_list, | ||||
| @@ -785,17 +821,6 @@ nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNo | |||||
| output_shape.push_back(1); | output_shape.push_back(1); | ||||
| } | } | ||||
| output_desc_json[kJsonKeyShape] = output_shape; | output_desc_json[kJsonKeyShape] = output_shape; | ||||
| if (auto tcnode = tmp_output.first->cast<CNodePtr>(); | |||||
| tcnode && AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) { | |||||
| auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo); | |||||
| if (info.size() != 2) { | |||||
| MS_LOG(EXCEPTION) << "Parallel dim info is invalid!"; | |||||
| } | |||||
| sub_graphs_[info[0]].push_back(output_desc_json[kJsonKeyTensorName]); | |||||
| if (dim_infos_.find(info[0]) == dim_infos_.end()) { | |||||
| dim_infos_[info[0]] = info[1]; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| outputs_json.emplace_back(output_desc_json); | outputs_json.emplace_back(output_desc_json); | ||||
| } | } | ||||
| @@ -54,6 +54,7 @@ constexpr auto kJsonKeyParallelFusion = "parallel_fusion"; | |||||
| constexpr auto kJsonKeyFusionType = "fusion_type"; | constexpr auto kJsonKeyFusionType = "fusion_type"; | ||||
| constexpr auto kJsonKeySubGraph = "sub_graph"; | constexpr auto kJsonKeySubGraph = "sub_graph"; | ||||
| constexpr auto kJsonKeyCoreNum = "core_num"; | constexpr auto kJsonKeyCoreNum = "core_num"; | ||||
| constexpr auto kJsonKeyTypeInfo = "type_info"; | |||||
| constexpr auto kJsonKeyBufferStitch = "buffer_stitch"; | constexpr auto kJsonKeyBufferStitch = "buffer_stitch"; | ||||
| constexpr auto kJsonKeyStitchOp = "stitch_op"; | constexpr auto kJsonKeyStitchOp = "stitch_op"; | ||||
| constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op"; | constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op"; | ||||
| @@ -89,8 +90,6 @@ class AkgKernelJsonGenerator { | |||||
| input_tensor_idx_.clear(); | input_tensor_idx_.clear(); | ||||
| address_node_map_.clear(); | address_node_map_.clear(); | ||||
| output_tensor_idx_ = 0; | output_tensor_idx_ = 0; | ||||
| sub_graphs_.clear(); | |||||
| dim_infos_.clear(); | |||||
| } | } | ||||
| void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; } | void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; } | ||||
| std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; } | std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; } | ||||
| @@ -127,9 +126,10 @@ class AkgKernelJsonGenerator { | |||||
| std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index); | std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index); | ||||
| void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json); | void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json); | ||||
| OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node); | OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node); | ||||
| void SetParallelValueToJson(const std::string &processor, const std::map<size_t, size_t> &dim_infos, | |||||
| nlohmann::json *sub_fusion_json); | |||||
| void AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json); | |||||
| void CollectParallelDimInfo(const AnfNodePtr &anf_node); | |||||
| void GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list, | |||||
| const std::vector<AnfNodePtr> &output_list, | |||||
| const std::map<AnfNodePtr, nlohmann::json> &node_json_map, nlohmann::json *kernel_json); | |||||
| DumpOption dump_option_; | DumpOption dump_option_; | ||||
| static int op_cnt_; | static int op_cnt_; | ||||
| @@ -142,8 +142,6 @@ class AkgKernelJsonGenerator { | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| std::map<std::string, AnfNodePtr> address_node_map_; | std::map<std::string, AnfNodePtr> address_node_map_; | ||||
| std::map<size_t, std::vector<std::string>> sub_graphs_; | |||||
| std::map<size_t, size_t> dim_infos_; | |||||
| bool is_basic_op_{false}; | bool is_basic_op_{false}; | ||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -60,8 +60,9 @@ KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_ge | |||||
| return cached_kernel_pack; | return cached_kernel_pack; | ||||
| } | } | ||||
| (void)alarm(AUTODIFF_COMPILE_OVERTIME); | |||||
| auto kernel_json = json_generator.kernel_json_str(); | auto kernel_json = json_generator.kernel_json_str(); | ||||
| kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path()); | |||||
| (void)alarm(AUTODIFF_COMPILE_OVERTIME); | |||||
| auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(kernel_json); | auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(kernel_json); | ||||
| (void)alarm(0); | (void)alarm(0); | ||||
| if (!res) { | if (!res) { | ||||
| @@ -70,7 +71,6 @@ KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_ge | |||||
| } | } | ||||
| auto new_kernel_pack = InsertCache(kernel_name, processor); | auto new_kernel_pack = InsertCache(kernel_name, processor); | ||||
| kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path()); | |||||
| if (new_kernel_pack == nullptr) { | if (new_kernel_pack == nullptr) { | ||||
| MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope[" | MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope[" | ||||
| << anf_node->fullname_with_scope() << "]."; | << anf_node->fullname_with_scope() << "]."; | ||||
| @@ -47,7 +47,7 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) { | |||||
| return py::cast<int>(ret); | return py::cast<int>(ret); | ||||
| } | } | ||||
| std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) { | |||||
| std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) { | |||||
| nlohmann::json json_desc; | nlohmann::json json_desc; | ||||
| std::vector<AnfNodePtrList> graphs; | std::vector<AnfNodePtrList> graphs; | ||||
| std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs), | std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs), | ||||
| @@ -65,7 +65,7 @@ std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const An | |||||
| } | } | ||||
| py::tuple ret_tuple = py::cast<py::tuple>(ret); | py::tuple ret_tuple = py::cast<py::tuple>(ret); | ||||
| if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 2) { | |||||
| if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 4) { | |||||
| MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!"; | MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!"; | ||||
| } | } | ||||
| @@ -75,8 +75,41 @@ std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const An | |||||
| dim_infos.push_back(std::make_shared<CommonDimInfo>(py::cast<int>(dim_list[i]))); | dim_infos.push_back(std::make_shared<CommonDimInfo>(py::cast<int>(dim_list[i]))); | ||||
| } | } | ||||
| int benefit = py::cast<int>(ret_tuple[1]); | int benefit = py::cast<int>(ret_tuple[1]); | ||||
| auto fusion_info = ProcessFusionInfo(ret_tuple[2], ret_tuple[3]); | |||||
| return std::make_tuple(dim_infos, benefit); | |||||
| return std::make_tuple(dim_infos, benefit, fusion_info); | |||||
| } | |||||
| FusionInfoPtr ParallelCostModel::ProcessFusionInfo(py::object fusion_type, py::object type_info) { | |||||
| if (!py::isinstance<py::str>(fusion_type)) { | |||||
| MS_LOG(EXCEPTION) << "Fusion type for parallel is invalid!"; | |||||
| } | |||||
| std::string fusion_type_name = py::cast<std::string>(fusion_type); | |||||
| FusionInfoPtr fusion_info; | |||||
| if (fusion_type_name == "block_fusion") { | |||||
| fusion_info = std::make_shared<BlockFusionInfo>(); | |||||
| } else if (fusion_type_name == "block_pipeline_fusion") { | |||||
| if (!py::isinstance<py::list>(type_info)) { | |||||
| MS_LOG(EXCEPTION) << "Fusion type info for block pipe fusion type is invalid!"; | |||||
| } | |||||
| std::vector<std::vector<int>> pipeline_ids; | |||||
| py::list pipeline_ids_list = py::cast<py::list>(type_info); | |||||
| for (size_t i = 0; i < pipeline_ids_list.size(); ++i) { | |||||
| std::vector<int> part_ids; | |||||
| py::list inner_ids_list = py::cast<py::list>(pipeline_ids_list[i]); | |||||
| for (size_t j = 0; j < inner_ids_list.size(); ++j) { | |||||
| part_ids.push_back(py::cast<int>(inner_ids_list[j])); | |||||
| } | |||||
| pipeline_ids.push_back(part_ids); | |||||
| } | |||||
| fusion_info = std::make_shared<BlockPipelineFusionInfo>(pipeline_ids); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unsupported parallel fusion type: " << fusion_type_name; | |||||
| } | |||||
| return fusion_info; | |||||
| } | } | ||||
| ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) { | ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) { | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "backend/optimizer/common/optimizer.h" | #include "backend/optimizer/common/optimizer.h" | ||||
| #include "backend/optimizer/graph_kernel/parallel_cost_model.h" | #include "backend/optimizer/graph_kernel/parallel_cost_model.h" | ||||
| #include "backend/session/kernel_graph.h" | #include "backend/session/kernel_graph.h" | ||||
| #include "pipeline/jit/parse/python_adapter.h" | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -55,12 +56,50 @@ class CommonDimInfo : public DimInfo { | |||||
| using DimInfoPtr = std::shared_ptr<DimInfo>; | using DimInfoPtr = std::shared_ptr<DimInfo>; | ||||
| using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>; | using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>; | ||||
| class FusionInfo { | |||||
| public: | |||||
| FusionInfo() = default; | |||||
| explicit FusionInfo(const std::string &type) : fusion_type_(type) {} | |||||
| ~FusionInfo() = default; | |||||
| std::string FusionType() { return fusion_type_; } | |||||
| virtual bool ExistTypeInfo() { return false; } | |||||
| private: | |||||
| std::string fusion_type_{"none"}; | |||||
| }; | |||||
| class BlockFusionInfo : public FusionInfo { | |||||
| public: | |||||
| BlockFusionInfo() : FusionInfo("block_fusion") {} | |||||
| ~BlockFusionInfo() = default; | |||||
| bool ExistTypeInfo() { return false; } | |||||
| }; | |||||
| class BlockPipelineFusionInfo : public FusionInfo { | |||||
| public: | |||||
| explicit BlockPipelineFusionInfo(const std::vector<std::vector<int>> &ids) | |||||
| : FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {} | |||||
| ~BlockPipelineFusionInfo() = default; | |||||
| bool ExistTypeInfo() { return true; } | |||||
| std::vector<std::vector<int>> PipelineIds() { return pipeline_ids_; } | |||||
| private: | |||||
| std::vector<std::vector<int>> pipeline_ids_; | |||||
| }; | |||||
| using FusionInfoPtr = std::shared_ptr<FusionInfo>; | |||||
| using BlockFusionInfoPtr = std::shared_ptr<BlockFusionInfo>; | |||||
| using BlockPipelineFusionInfoPtr = std::shared_ptr<BlockPipelineFusionInfo>; | |||||
| class ParallelCostModel { | class ParallelCostModel { | ||||
| public: | public: | ||||
| ParallelCostModel() {} | ParallelCostModel() {} | ||||
| ~ParallelCostModel() {} | ~ParallelCostModel() {} | ||||
| int GetNodeCalAmount(const AnfNodePtr &node); | int GetNodeCalAmount(const AnfNodePtr &node); | ||||
| std::tuple<std::vector<DimInfoPtr>, int> CalFuseInfo(const AnfNodePtrList &nodes); | |||||
| std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes); | |||||
| private: | |||||
| FusionInfoPtr ProcessFusionInfo(py::object fusion_type, py::object type_info); | |||||
| }; | }; | ||||
| using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>; | using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>; | ||||
| @@ -553,7 +553,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||||
| std::tie(other_candidates, std::ignore) = | std::tie(other_candidates, std::ignore) = | ||||
| GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>()); | GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>()); | ||||
| int benefit; | int benefit; | ||||
| std::tie(std::ignore, benefit) = cost_model_ptr_->CalFuseInfo(other_candidates); | |||||
| std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates); | |||||
| if (benefit > 0) { | if (benefit > 0) { | ||||
| begin = mid + 1; | begin = mid + 1; | ||||
| } else { | } else { | ||||
| @@ -567,12 +567,12 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||||
| AnfNodePtrList other_candidates; | AnfNodePtrList other_candidates; | ||||
| std::tie(other_candidates, std::ignore) = | std::tie(other_candidates, std::ignore) = | ||||
| GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>()); | GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>()); | ||||
| auto [dim_infos, benefit] = cost_model_ptr_->CalFuseInfo(other_candidates); | |||||
| auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates); | |||||
| if (benefit <= 0) { | if (benefit <= 0) { | ||||
| MS_LOG(EXCEPTION) << "Internal error in candidate search!"; | MS_LOG(EXCEPTION) << "Internal error in candidate search!"; | ||||
| } | } | ||||
| max_benefit = benefit; | max_benefit = benefit; | ||||
| best_parallel_info = ParallelInfo(other_candidates, dim_infos); | |||||
| best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info); | |||||
| i += begin - 1; | i += begin - 1; | ||||
| } | } | ||||
| @@ -676,10 +676,13 @@ std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes( | |||||
| } | } | ||||
| void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info) { | void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info) { | ||||
| AnfNodePtr attach_node; | |||||
| // Dim info should be attach to each segment's output. | |||||
| for (size_t i = 0; i < parallel_info.GetSize(); ++i) { | for (size_t i = 0; i < parallel_info.GetSize(); ++i) { | ||||
| const auto &fuse_nodes = parallel_info.nodes(); | const auto &fuse_nodes = parallel_info.nodes(); | ||||
| std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()}; | std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()}; | ||||
| if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) { | if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) { | ||||
| attach_node = fuse_nodes[i]; | |||||
| SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]); | SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]); | ||||
| } else { | } else { | ||||
| auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0)); | auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0)); | ||||
| @@ -689,11 +692,16 @@ void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &pa | |||||
| for (size_t j = 1; j < inputs.size(); ++j) { | for (size_t j = 1; j < inputs.size(); ++j) { | ||||
| SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]); | SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]); | ||||
| } | } | ||||
| attach_node = inputs[1]; | |||||
| } else { | } else { | ||||
| attach_node = out_node; | |||||
| SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node); | SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // Fusion info is ok to attach to one of the segments. | |||||
| SetFusionInfoAttrToNode(attach_node, parallel_info); | |||||
| } | } | ||||
| void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) { | void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) { | ||||
| @@ -741,6 +749,17 @@ void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_pt | |||||
| } | } | ||||
| } | } | ||||
| void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo ¶llel_info) { | |||||
| auto fusion_type = parallel_info.fusion_info()->FusionType(); | |||||
| AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node); | |||||
| if (parallel_info.fusion_info()->ExistTypeInfo()) { | |||||
| if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) { | |||||
| AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo, | |||||
| MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node); | |||||
| } | |||||
| } | |||||
| } | |||||
| bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos, | bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos, | ||||
| const std::shared_ptr<session::KernelGraph> &kernel_graph) { | const std::shared_ptr<session::KernelGraph> &kernel_graph) { | ||||
| bool changed = false; | bool changed = false; | ||||
| @@ -755,6 +774,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> | |||||
| AnfNodePtr sg_node; | AnfNodePtr sg_node; | ||||
| std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel"); | std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel"); | ||||
| PostProcessForNewSubGraphCNode(sg_node, kernel_graph); | PostProcessForNewSubGraphCNode(sg_node, kernel_graph); | ||||
| AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node); | |||||
| DumpParallelFusionDetail(fuse_nodes, sg_node); | DumpParallelFusionDetail(fuse_nodes, sg_node); | ||||
| } | } | ||||
| @@ -37,10 +37,12 @@ namespace opt { | |||||
| class ParallelInfo { | class ParallelInfo { | ||||
| public: | public: | ||||
| ParallelInfo() = default; | ParallelInfo() = default; | ||||
| ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims) : nodes_(nodes), dims_(dims) {} | |||||
| ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims, const FusionInfoPtr &fusion_info) | |||||
| : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {} | |||||
| ParallelInfo(const ParallelInfo &obj) { | ParallelInfo(const ParallelInfo &obj) { | ||||
| nodes_ = obj.nodes_; | nodes_ = obj.nodes_; | ||||
| dims_ = obj.dims_; | dims_ = obj.dims_; | ||||
| fusion_info_ = obj.fusion_info_; | |||||
| } | } | ||||
| ~ParallelInfo() = default; | ~ParallelInfo() = default; | ||||
| @@ -52,10 +54,12 @@ class ParallelInfo { | |||||
| } | } | ||||
| const AnfNodePtrList &nodes() const { return nodes_; } | const AnfNodePtrList &nodes() const { return nodes_; } | ||||
| const std::vector<DimInfoPtr> &dims() const { return dims_; } | const std::vector<DimInfoPtr> &dims() const { return dims_; } | ||||
| const FusionInfoPtr &fusion_info() const { return fusion_info_; } | |||||
| private: | private: | ||||
| AnfNodePtrList nodes_; | AnfNodePtrList nodes_; | ||||
| std::vector<DimInfoPtr> dims_; | std::vector<DimInfoPtr> dims_; | ||||
| FusionInfoPtr fusion_info_; | |||||
| }; | }; | ||||
| class ParallelConfig { | class ParallelConfig { | ||||
| @@ -102,6 +106,8 @@ class ParallelOpFusion : public Pass { | |||||
| std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups); | std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups); | ||||
| void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo ¶llel_info); | |||||
| void SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info); | void SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info); | ||||
| bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos, | bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos, | ||||
| @@ -397,6 +397,9 @@ constexpr auto kAttrIsGrad = "is_grad"; | |||||
| constexpr auto kAttrRecompute = "recompute"; | constexpr auto kAttrRecompute = "recompute"; | ||||
| constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; | constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; | ||||
| constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; | constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; | ||||
| constexpr auto kAttrParallelFusionType = "parallel_fusion_type"; | |||||
| constexpr auto kAttrParallelTypeInfo = "parallel_type_info"; | |||||
| constexpr auto kAttrCompositeType = "composite_type"; | |||||
| constexpr auto kAttrStitch = "stitch"; | constexpr auto kAttrStitch = "stitch"; | ||||
| constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first"; | constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first"; | ||||
| constexpr auto kAttrSwitchLayer = "switch_layer"; | constexpr auto kAttrSwitchLayer = "switch_layer"; | ||||