| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """Cost model splitter""" | """Cost model splitter""" | ||||
| import os | |||||
| from functools import reduce | from functools import reduce | ||||
| from .model import PrimLib, Graph, Tensor | from .model import PrimLib, Graph, Tensor | ||||
| @@ -23,12 +24,19 @@ class GraphSplitByPattern: | |||||
| MODE_BASIC = 1 | MODE_BASIC = 1 | ||||
| MODE_COMPOSITE = 2 | MODE_COMPOSITE = 2 | ||||
class StitchInfo:
    """Holds tensor names selected for buffer stitching.

    ``stitch_ops`` collects outputs stitched the common way, while
    ``stitch_atomic_ops`` collects outputs handled by atomic add
    (see the decoder side, which maps them to "common"/"atomic" attrs).
    """

    def __init__(self):
        # Both collections start out empty and are filled during fusion.
        self.stitch_ops, self.stitch_atomic_ops = set(), set()
| def __init__(self, init_op, is_output): | def __init__(self, init_op, is_output): | ||||
| self.pattern = PrimLib.iter_type(init_op) | self.pattern = PrimLib.iter_type(init_op) | ||||
| self.ops = [init_op] | self.ops = [init_op] | ||||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | ||||
| self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | ||||
| self.mode = None | self.mode = None | ||||
| self.stitch_info = self.StitchInfo() | |||||
| self.is_output = is_output | self.is_output = is_output | ||||
| self.output_excluded = set() | self.output_excluded = set() | ||||
| if self.pattern == PrimLib.REDUCE: | if self.pattern == PrimLib.REDUCE: | ||||
| @@ -69,6 +77,12 @@ class GraphSplitByPattern: | |||||
| for input_area, r in self.in_relations.items(): | for input_area, r in self.in_relations.items(): | ||||
| input_area.out_relations[self] = r | input_area.out_relations[self] = r | ||||
def update_stitch_info(self, stitch_info):
    """Merge another area's stitch info into this area's stitch info.

    Empty collections are skipped, so merging an area without stitch
    data is a no-op.
    """
    own = self.stitch_info
    if stitch_info.stitch_ops:
        own.stitch_ops.update(stitch_info.stitch_ops)
    if stitch_info.stitch_atomic_ops:
        own.stitch_atomic_ops.update(stitch_info.stitch_atomic_ops)
| def fuse(self, area): | def fuse(self, area): | ||||
| """Fuse `area` to `self`""" | """Fuse `area` to `self`""" | ||||
| def _update_relation(relations, a, r): | def _update_relation(relations, a, r): | ||||
| @@ -107,6 +121,7 @@ class GraphSplitByPattern: | |||||
| self.is_output = True | self.is_output = True | ||||
| if area.output_excluded: | if area.output_excluded: | ||||
| self.output_excluded.update(area.output_excluded) | self.output_excluded.update(area.output_excluded) | ||||
| self.update_stitch_info(area.stitch_info) | |||||
| def check_circle(self, to): | def check_circle(self, to): | ||||
| """Check circle. It returns false if circle exists""" | """Check circle. It returns false if circle exists""" | ||||
| @@ -181,10 +196,25 @@ class GraphSplitByPattern: | |||||
| graphmodes = [] | graphmodes = [] | ||||
| for i, area in enumerate(self.areas): | for i, area in enumerate(self.areas): | ||||
| area.ops.sort(key=lambda op: ids[op]) | area.ops.sort(key=lambda op: ids[op]) | ||||
| subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops)) | |||||
| subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info)) | |||||
| graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite") | graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite") | ||||
| return subgraphs, graphmodes | return subgraphs, graphmodes | ||||
def dump_subgraphs(self, subgraphs):
    """Dump `subgraphs` to a log file when dumping is enabled.

    Enabled by setting the environment variable ENABLE_SUBGRAPHS to
    "on". The dump is written to ./subgraphs/<graph_name>.log in the
    current working directory.
    """
    if os.environ.get("ENABLE_SUBGRAPHS", "off") != "on":
        return
    # Build the text in pieces and join once (avoids quadratic +=).
    parts = ["subgraphs:\nlen: " + str(len(subgraphs)) + "\n"]
    for i, sub in enumerate(subgraphs):
        parts.append("============" + str(i) + "\n")
        parts.append(str(sub))
    dirname = 'subgraphs'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(dirname, exist_ok=True)
    filename = os.path.join(dirname, self.graph.name + '.log')
    with open(filename, 'w') as f:
        f.write("".join(parts))
| def split(self): | def split(self): | ||||
| """Split graph by pattern""" | """Split graph by pattern""" | ||||
| self.do_split() | self.do_split() | ||||
| @@ -192,6 +222,7 @@ class GraphSplitByPattern: | |||||
| # Note: after this function, the input output relation is not maintained. | # Note: after this function, the input output relation is not maintained. | ||||
| self.split_output_reshapes() | self.split_output_reshapes() | ||||
| subgraphs, graphmodes = self.to_subgraphs() | subgraphs, graphmodes = self.to_subgraphs() | ||||
| self.dump_subgraphs(subgraphs) | |||||
| return subgraphs, graphmodes | return subgraphs, graphmodes | ||||
| def split_output_reshapes(self): | def split_output_reshapes(self): | ||||
| @@ -362,15 +393,25 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| return reduce_size >= 1024 | return reduce_size >= 1024 | ||||
| return True | return True | ||||
| def _reduce_nums(ops): | |||||
| count = 0 | |||||
| for op in ops: | |||||
| if op.prim.startswith('Reduce'): | |||||
| count += 1 | |||||
| return count | |||||
| def _reduce_output(dom): | def _reduce_output(dom): | ||||
| if dom.pattern != PrimLib.REDUCE: | if dom.pattern != PrimLib.REDUCE: | ||||
| return None | return None | ||||
| if _reduce_nums(dom.ops) > 1: | |||||
| return None | |||||
| if _is_atomic_add_available(dom): | if _is_atomic_add_available(dom): | ||||
| return None | return None | ||||
| is_all_reduce = _tensor_size(dom.ops[0].output) == 1 | is_all_reduce = _tensor_size(dom.ops[0].output) == 1 | ||||
| # excluded large size all reduce | # excluded large size all reduce | ||||
| if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: | if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: | ||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.out_relations.items(): | for a, r in dom.out_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | ||||
| @@ -378,6 +419,24 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused, False | return fused, False | ||||
def _reduce_stitch(dom):
    """Select successors of a large reduce area to fuse via buffer stitch.

    Qualifying successor areas are returned for fusion, and the reduce
    output's tensor name is recorded in ``dom.stitch_info.stitch_ops``.
    Returns None when `dom` is not a stitch candidate.
    """
    if dom.pattern != PrimLib.REDUCE:
        return None
    # Scalar-output (all-reduce) domains are not stitched.
    if _tensor_size(dom.ops[0].output) == 1:
        return None
    # Only large reductions benefit from stitching.
    if _tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
        return None
    fused = []
    for area, rel in dom.out_relations.items():
        if area.pattern > PrimLib.REDUCE or rel > PrimLib.BROADCAST:
            continue
        if not dom.check_circle(area):
            continue
        if _reduce_nums(area.ops) >= 2:
            continue
        # softmax-like successor: enough ops over a 4-D input
        if len(area.ops) > 4 and len(area.ops[0].inputs[0].shape) == 4:
            dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
            fused.append(area)
    return fused, False
| def _transpose(dom): | def _transpose(dom): | ||||
| if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose": | if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose": | ||||
| return None | return None | ||||
| @@ -398,6 +457,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| changed = self.fuse(_broadcast_width) or changed | changed = self.fuse(_broadcast_width) or changed | ||||
| if use_poly_reduce: | if use_poly_reduce: | ||||
| changed = self.fuse(_reduce_output) or changed | changed = self.fuse(_reduce_output) or changed | ||||
| changed = self.fuse(_reduce_stitch) or changed | |||||
| self.fuse(_transpose) | self.fuse(_transpose) | ||||
| class GraphSplitAscend(GraphSplitByPattern): | class GraphSplitAscend(GraphSplitByPattern): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -310,11 +310,12 @@ class Operator: | |||||
| class Graph: | class Graph: | ||||
| """Graph""" | """Graph""" | ||||
def __init__(self, name, ops, stitch_info=None):
    """Graph constructor.

    Args:
        name: name of the graph.
        ops: operators kept in topological order (order matters, so a
            set cannot be used).
        stitch_info: optional stitch information (tensor names for
            buffer stitching); None when the graph has no stitch data.
    """
    self.name = name
    self.ops = ops  # in topo order, can not use set
    self.inputs = []
    self.outputs = []
    self.stitch_info = stitch_info
| def set_processor(self, processor): | def set_processor(self, processor): | ||||
| """Set processor""" | """Set processor""" | ||||
| @@ -372,6 +373,12 @@ class Graph: | |||||
| out_str = ', '.join([repr(t) for t in outputs]) | out_str = ', '.join([repr(t) for t in outputs]) | ||||
| lines = [] | lines = [] | ||||
| lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str)) | lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str)) | ||||
| if self.stitch_info: | |||||
| if self.stitch_info.stitch_ops: | |||||
| lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops)) | |||||
| if self.stitch_info.stitch_atomic_ops: | |||||
| lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops)) | |||||
| for op in self.ops: | for op in self.ops: | ||||
| lines.append(' ' + str(op)) | lines.append(' ' + str(op)) | ||||
| lines.append('}') | lines.append('}') | ||||
| @@ -405,12 +412,20 @@ class Graph: | |||||
| in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape, | in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape, | ||||
| 'tensor_name': t.name, 'format': t.data_format}]) | 'tensor_name': t.name, 'format': t.data_format}]) | ||||
| out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape, | out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape, | ||||
| 'tensor_name': op.output.name, 'format': t.data_format}] | |||||
| 'tensor_name': op.output.name, 'format': op.output.data_format}] | |||||
| op_desc.append({'attr': attrs, 'impl_path': '', | op_desc.append({'attr': attrs, 'impl_path': '', | ||||
| 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc}) | 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc}) | ||||
| graph_desc = {'composite': True, 'composite_graph': '', 'id': 0, | graph_desc = {'composite': True, 'composite_graph': '', 'id': 0, | ||||
| 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc, | 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc, | ||||
| 'platform': 'AKG', 'process': self.processor} | 'platform': 'AKG', 'process': self.processor} | ||||
| if self.stitch_info and self.stitch_info.stitch_ops: | |||||
| buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)} | |||||
| if self.stitch_info.stitch_atomic_ops: | |||||
| buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops) | |||||
| graph_desc['buffer_stitch'] = buffer_stitch | |||||
| return graph_desc | return graph_desc | ||||
| @@ -313,6 +313,14 @@ class CompositeGraph: | |||||
| self.graph = builder.get()[0] | self.graph = builder.get()[0] | ||||
| self.desc = desc | self.desc = desc | ||||
def add_stitch_info(self, subgraph, desc):
    """Attach `subgraph`'s stitch information to the json dict `desc`.

    When the subgraph carries stitch ops, a 'buffer_stitch' entry is
    added to `desc`; atomic stitch ops are included only when present.
    Returns `desc` (also mutated in place).
    """
    info = subgraph.stitch_info
    if info and info.stitch_ops:
        buffer_stitch = {'stitch_op': list(info.stitch_ops)}
        if info.stitch_atomic_ops:
            buffer_stitch['stitch_atomic_op'] = list(info.stitch_atomic_ops)
        desc['buffer_stitch'] = buffer_stitch
    return desc
| def dump(self, subgraph): | def dump(self, subgraph): | ||||
| """Dump Graph to json""" | """Dump Graph to json""" | ||||
| desc = {} | desc = {} | ||||
| @@ -368,6 +376,8 @@ class CompositeGraph: | |||||
| desc[key] = subgraph.name | desc[key] = subgraph.name | ||||
| else: | else: | ||||
| desc[key] = self.desc[key] | desc[key] = self.desc[key] | ||||
| desc = self.add_stitch_info(subgraph, desc) | |||||
| return desc | return desc | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -414,6 +414,35 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js | |||||
| return DecodeFusedNodes(kernel_json); | return DecodeFusedNodes(kernel_json); | ||||
| } | } | ||||
// Parse the optional "buffer_stitch" section of `kernel_json` into a
// StitchInfo. Missing keys simply leave the corresponding vector empty.
StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) {
  StitchInfo info;
  if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
    nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
    if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
      // Tensor names stitched as "common" (see SetStitchAttr).
      std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
      info.stitch_ops = stitch_ops;
    }
    if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
      // Tensor names stitched as "atomic" (see SetStitchAttr).
      std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
      info.stitch_atomic_ops = stitch_atomic_ops;
    }
  }
  return info;
}
// Tag `node` with the kAttrStitch attribute ("common" or "atomic") when its
// first output tensor name appears in the decoded stitch info. Nodes without
// an output tensor name are skipped.
void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
  std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
  if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
  std::string tensor_name = output_descs[0][kJsonKeyTensorName];
  if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
    AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
  }
  // NOTE(review): a name present in both lists ends up "atomic", since this
  // second SetNodeAttr runs after the first — confirm that is intended.
  if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
      info.stitch_atomic_ops.end()) {
    AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
  }
}
| bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | ||||
| const std::map<std::string, AnfNodePtr> &address_node_map, | const std::map<std::string, AnfNodePtr> &address_node_map, | ||||
| AnfNodePtrList *res_graphs) { | AnfNodePtrList *res_graphs) { | ||||
| @@ -425,6 +454,7 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | |||||
| MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json; | MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json; | ||||
| return false; | return false; | ||||
| } | } | ||||
| StitchInfo info = GetStitchInfo(kernel_json); | |||||
| for (const auto &op_desc : op_node_descs) { | for (const auto &op_desc : op_node_descs) { | ||||
| if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) { | if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) { | ||||
| MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc; | MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc; | ||||
| @@ -436,7 +466,9 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | |||||
| MS_LOG(ERROR) << "Decode failed, ptr_address not found in map."; | MS_LOG(ERROR) << "Decode failed, ptr_address not found in map."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| res_graphs->push_back(address_node_map.at(ptr_address)); | |||||
| auto node = address_node_map.at(ptr_address)->cast<CNodePtr>(); | |||||
| SetStitchAttr(op_desc, info, node); | |||||
| res_graphs->push_back(node); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size(); | MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size(); | ||||
| return true; | return true; | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -26,6 +26,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
// Stitch information decoded from a kernel json's "buffer_stitch" section
// (see AkgKernelJsonDecoder::GetStitchInfo / SetStitchAttr).
struct StitchInfo {
  std::vector<std::string> stitch_ops;         // tensor names tagged "common"
  std::vector<std::string> stitch_atomic_ops;  // tensor names tagged "atomic"
};
| class AkgKernelJsonDecoder { | class AkgKernelJsonDecoder { | ||||
| public: | public: | ||||
| AkgKernelJsonDecoder() { nodes_map_.clear(); } | AkgKernelJsonDecoder() { nodes_map_.clear(); } | ||||
| @@ -40,6 +44,8 @@ class AkgKernelJsonDecoder { | |||||
| ParameterPtr DecodeParameter(const nlohmann::json ¶meter_json, const FuncGraphPtr &func_graph); | ParameterPtr DecodeParameter(const nlohmann::json ¶meter_json, const FuncGraphPtr &func_graph); | ||||
| CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor); | CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor); | ||||
| AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph); | AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph); | ||||
| StitchInfo GetStitchInfo(const nlohmann::json &kernel_json); | |||||
| void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node); | |||||
| std::map<std::string, AnfNodePtr> nodes_map_; | std::map<std::string, AnfNodePtr> nodes_map_; | ||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -580,6 +580,31 @@ void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &process | |||||
| (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json; | (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json; | ||||
| } | } | ||||
// Collect the output tensor names of nodes whose kAttrStitch attribute equals
// "common" and emit them under (*kernel_json)[kJsonKeyBufferStitch]
// [kJsonKeyStitchOp]. Nothing is written when no such node exists.
void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes,
                                           std::map<AnfNodePtr, nlohmann::json> *node_json_map,
                                           nlohmann::json *kernel_json) {
  std::vector<std::string> stitchs;
  for (auto const &anf_node : anf_nodes) {
    if (AnfAlgo::HasNodeAttr(kAttrStitch, anf_node->cast<CNodePtr>()) &&
        AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrStitch) == "common") {
      auto name = GetTensorName((*node_json_map)[anf_node], kJsonKeyOutputDesc, {0, 0});
      // De-duplicate while preserving first-seen order.
      if (std::find(stitchs.begin(), stitchs.end(), name) == stitchs.end()) {
        stitchs.emplace_back(name);
      }
    }
  }
  if (!stitchs.empty()) {
    // Each stitch op name is wrapped in its own single-element list, so the
    // json value is a list of lists of strings.
    std::vector<nlohmann::json> v;
    for (auto &s : stitchs) {
      std::vector<std::string> t;
      t.emplace_back(s);
      v.emplace_back(t);
    }
    nlohmann::json stitch_json;
    stitch_json[kJsonKeyStitchOp] = v;
    (*kernel_json)[kJsonKeyBufferStitch] = stitch_json;
  }
}
| bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, | bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, | ||||
| const std::vector<AnfNodePtr> &input_list, | const std::vector<AnfNodePtr> &input_list, | ||||
| const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) { | const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) { | ||||
| @@ -637,6 +662,8 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf | |||||
| (*kernel_json)[kJsonKeyComposite] = true; | (*kernel_json)[kJsonKeyComposite] = true; | ||||
| (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id(); | (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id(); | ||||
| GenStitchJson(anf_nodes, &node_json_map, kernel_json); | |||||
| if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) { | if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) { | ||||
| MS_LOG(ERROR) << "Cal mem size failed."; | MS_LOG(ERROR) << "Cal mem size failed."; | ||||
| return false; | return false; | ||||
| @@ -54,6 +54,9 @@ constexpr auto kJsonKeyParallelFusion = "parallel_fusion"; | |||||
| constexpr auto kJsonKeyFusionType = "fusion_type"; | constexpr auto kJsonKeyFusionType = "fusion_type"; | ||||
| constexpr auto kJsonKeySubGraph = "sub_graph"; | constexpr auto kJsonKeySubGraph = "sub_graph"; | ||||
| constexpr auto kJsonKeyCoreNum = "core_num"; | constexpr auto kJsonKeyCoreNum = "core_num"; | ||||
| constexpr auto kJsonKeyBufferStitch = "buffer_stitch"; | |||||
| constexpr auto kJsonKeyStitchOp = "stitch_op"; | |||||
| constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op"; | |||||
| constexpr auto kAttrInputNames = "input_names"; | constexpr auto kAttrInputNames = "input_names"; | ||||
| @@ -98,6 +101,8 @@ class AkgKernelJsonGenerator { | |||||
| void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes, const OpAttrPtr &op_attr, | void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes, const OpAttrPtr &op_attr, | ||||
| nlohmann::json *attr_json, const ValuePtr &attr_value); | nlohmann::json *attr_json, const ValuePtr &attr_value); | ||||
| bool CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json); | bool CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json); | ||||
| void GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map, | |||||
| nlohmann::json *kernel_json); | |||||
| bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size, std::vector<size_t> *output_size); | bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size, std::vector<size_t> *output_size); | ||||
| bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map); | bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map); | ||||
| void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map); | void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <tuple> | #include <tuple> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | |||||
| #include "backend/optimizer/common/optimizer.h" | #include "backend/optimizer/common/optimizer.h" | ||||
| #include "backend/session/kernel_graph.h" | #include "backend/session/kernel_graph.h" | ||||
| @@ -28,21 +29,24 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| class AtomicCleanInsertter : public Pass { | class AtomicCleanInsertter : public Pass { | ||||
| public: | public: | ||||
| AtomicCleanInsertter() : Pass("atomic_clean") {} | |||||
| explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {} | |||||
| ~AtomicCleanInsertter() override = default; | ~AtomicCleanInsertter() override = default; | ||||
| bool Run(const FuncGraphPtr &func_graph) override; | |||||
| virtual bool Run(const FuncGraphPtr &func_graph); | |||||
| private: | |||||
| void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, | |||||
| const FuncGraphManagerPtr &mng); | |||||
| bool CanActivateAtomicAdd(const AnfNodePtr &anf_node); | |||||
| protected: | |||||
| virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); | |||||
| virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, | |||||
| const FuncGraphManagerPtr &mng); | |||||
| void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); | void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); | ||||
| void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, | void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, | ||||
| const AnfNodePtr &user_node, int index); | const AnfNodePtr &user_node, int index); | ||||
| void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); | |||||
| CNodePtr atomic_add_node_{nullptr}; | |||||
| private: | |||||
| bool CanActivateAtomicAdd(const AnfNodePtr &anf_node); | |||||
| void CorrectAbstract(const AnfNodePtr &composite_node); | void CorrectAbstract(const AnfNodePtr &composite_node); | ||||
| void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); | |||||
| CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type); | CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type); | ||||
| void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); | |||||
| void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, | void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, | ||||
| const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng); | const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng); | ||||
| std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph); | std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph); | ||||
| @@ -55,7 +59,6 @@ class AtomicCleanInsertter : public Pass { | |||||
| bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node, | bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node, | ||||
| const FuncGraphManagerPtr &mng); | const FuncGraphManagerPtr &mng); | ||||
| CNodePtr atomic_add_node_{nullptr}; | |||||
| size_t reduce_real_output_index_{0}; | size_t reduce_real_output_index_{0}; | ||||
| size_t real_output_num_{0}; | size_t real_output_num_{0}; | ||||
| std::vector<std::pair<AnfNodePtr, AnfNodePtr>> to_process_order_; | std::vector<std::pair<AnfNodePtr, AnfNodePtr>> to_process_order_; | ||||
| @@ -0,0 +1,200 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h" | |||||
| #include <algorithm> | |||||
| #include <functional> | |||||
| #include <list> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include <set> | |||||
| #include <stack> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include <vector> | |||||
| #include "base/core_ops.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "utils/utils.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "backend/kernel_compiler/kernel.h" | |||||
| #include "backend/kernel_compiler/common_utils.h" | |||||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "backend/session/kernel_graph.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
// Change kernel build info: rebuild the composite node's selected kernel
// build info so it gains one extra input slot describing `new_input`, while
// all original inputs/outputs keep their formats and device types.
void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
  // Change kernel build info.
  auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
  auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
  auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
  auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
  auto origin_processor = origin_kernel_build_info->processor();
  // NOTE(review): these are references onto the local copies above, so the
  // push_back calls below extend origin_inputs_* as well — intentional reuse,
  // not a bug, but worth keeping in mind when editing.
  std::vector<std::string> &new_inputs_format = origin_inputs_format;
  std::vector<TypeId> &new_inputs_type = origin_inputs_type;
  std::vector<std::string> new_outputs_format;
  std::vector<TypeId> new_outputs_type;
  // Outputs are copied through unchanged.
  for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
    new_outputs_format.push_back(origin_outputs_format[i]);
    new_outputs_type.push_back(origin_outputs_type[i]);
  }
  // Append format/type of the newly added input. VisitKernel presumably
  // resolves `new_input` to the real producing kernel and output index —
  // TODO confirm against AnfAlgo docs.
  auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
  new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
  new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
  kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
  new_info_builder.SetInputsFormat(new_inputs_format);
  new_info_builder.SetInputsDeviceType(new_inputs_type);
  new_info_builder.SetOutputsFormat(new_outputs_format);
  new_info_builder.SetOutputsDeviceType(new_outputs_type);
  new_info_builder.SetProcessor(origin_processor);
  new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto new_selected_info = new_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
| CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, | |||||
| const AnfNodePtr &new_parameter) { | |||||
| // add inplaceassign | |||||
| AnfNodePtr out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true. | |||||
| auto inplace_assign_node = | |||||
| CreateCNode({NewValueNode(std::make_shared<Primitive>("InplaceAssign")), new_parameter, atomic_add_node_, out_node}, | |||||
| sub_graph, {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)}); | |||||
| AnfAlgo::SetNodeAttr("fake_output", MakeValue(true), inplace_assign_node); | |||||
| AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_); | |||||
| AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), inplace_assign_node); | |||||
| return inplace_assign_node; | |||||
| } | |||||
| void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, | |||||
| const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng) { | |||||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node); | |||||
| auto mng_sub = sub_graph->manager(); | |||||
| if (mng_sub == nullptr) { | |||||
| mng_sub = Manage(sub_graph, false); | |||||
| sub_graph->set_manager(mng_sub); | |||||
| } | |||||
| // add input | |||||
| auto inputs = composite_node->cast<CNodePtr>()->inputs(); | |||||
| inputs.push_back(new_input); | |||||
| composite_node->cast<CNodePtr>()->set_inputs(inputs); | |||||
| // add parameter | |||||
| auto parameter = sub_graph->add_parameter(); | |||||
| parameter->set_abstract(new_input->abstract()); | |||||
| parameter->set_kernel_info(new_input->kernel_info_ptr()); | |||||
| auto inplace_assign = CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter); | |||||
| // Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid | |||||
| // elimination. | |||||
| std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_); | |||||
| bool connected = false; | |||||
| for (const auto &[user_node, index] : reduce_user_nodes) { | |||||
| auto user_cnode = user_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(user_cnode); | |||||
| user_cnode->set_input(index, parameter); | |||||
| if (!connected) { | |||||
| std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode); | |||||
| if (!user_user.empty()) { | |||||
| auto pair = user_user[0]; | |||||
| AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second); | |||||
| } | |||||
| connected = true; | |||||
| } | |||||
| CorrectKernelBuildInfo(composite_node, new_input); | |||||
| } | |||||
| auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||||
| auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); | |||||
| sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); | |||||
| MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name; | |||||
| } | |||||
| std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node, | |||||
| const CNodePtr &target) { | |||||
| auto node = inner_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||||
| auto mng_sub = sub_graph->manager(); | |||||
| if (mng_sub == nullptr) { | |||||
| mng_sub = Manage(sub_graph, false); | |||||
| sub_graph->set_manager(mng_sub); | |||||
| } | |||||
| std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes; | |||||
| auto users = mng_sub->node_users()[target]; | |||||
| std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes), | |||||
| [](const std::pair<AnfNodePtr, int> &pair) { return pair; }); | |||||
| return inner_user_nodes; | |||||
| } | |||||
| bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) { | |||||
| if (!AnfAlgo::IsGraphKernel(anf_node)) return false; | |||||
| auto node = anf_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||||
| AnfNodePtrList kernel_nodes; | |||||
| kernel::GetValidKernelNodes(sub_graph, &kernel_nodes); | |||||
| for (auto &n : kernel_nodes) { | |||||
| if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) && | |||||
| AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) { | |||||
| MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!"; | |||||
| atomic_add_node_ = n->cast<CNodePtr>(); | |||||
| stitch_node_ = anf_node; | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { | |||||
| auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto mng = kernel_graph->manager(); | |||||
| if (mng == nullptr) { | |||||
| mng = Manage(kernel_graph, true); | |||||
| kernel_graph->set_manager(mng); | |||||
| } | |||||
| bool changed = false; | |||||
| auto topo_nodes = TopoSort(kernel_graph->get_return()); | |||||
| for (const auto &node : topo_nodes) { | |||||
| // if stitch attr exists, add atomic clean op depends on the attr | |||||
| if (IsStitchWithAtomic(node)) { | |||||
| InsertAtomicClean(kernel_graph, node, mng); | |||||
| changed = true; | |||||
| } | |||||
| } | |||||
| if (changed) { | |||||
| mng->RemoveRoots(); | |||||
| mng->KeepRoots({func_graph}); | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_ | |||||
| #include <memory> | |||||
| #include <tuple> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" | |||||
| #include "backend/session/kernel_graph.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
// Pass that inserts atomic-clean ops for graph kernels whose inner ReduceSum carries
// the "stitch" attribute with value "atomic" (GPU buffer stitching).
class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
 public:
  StitchAtomicCleanInsertter() : AtomicCleanInsertter("stitch_atomic_clean") {}
  ~StitchAtomicCleanInsertter() override = default;
  // Returns true if any node was rewritten.
  bool Run(const FuncGraphPtr &func_graph) override;

 private:
  // Rebuild the composite's kernel build info after the workspace input is appended.
  // NOTE(review): the base class may declare this and the two methods below as virtual
  // (they would then need `override`) — confirm against AtomicCleanInsertter.
  void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
  // Create the InplaceAssign that routes the atomic ReduceSum result through `new_parameter`.
  CNodePtr CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
  // List (user, input_index) pairs of `target` inside the sub graph of `inner_node`.
  std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target);
  // Rewire the composite node to consume the new atomic-clean input.
  void ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
                          const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng);
  // Detect a stitch-atomic ReduceSum in `anf_node`; on success caches it in the members below.
  bool IsStitchWithAtomic(const AnfNodePtr &anf_node);

  // The graph-kernel composite node currently being rewritten (set by IsStitchWithAtomic).
  AnfNodePtr stitch_node_{nullptr};
};
| using StitchAtomicCleanInsertterPtr = std::shared_ptr<StitchAtomicCleanInsertter>; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_ | |||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -717,6 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() { | |||||
| prim::kPrimMinimumGrad, | prim::kPrimMinimumGrad, | ||||
| prim::kPrimGkDropout, | prim::kPrimGkDropout, | ||||
| prim::kPrimDropoutGrad, | prim::kPrimDropoutGrad, | ||||
| prim::kPrimSoftMax, | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| return expand_ops; | return expand_ops; | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -42,6 +42,7 @@ | |||||
| #include "backend/optimizer/gpu/add_relu_v2_fusion.h" | #include "backend/optimizer/gpu/add_relu_v2_fusion.h" | ||||
| #include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h" | #include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h" | ||||
| #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" | #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" | ||||
| #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h" | |||||
| #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | ||||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | ||||
| #include "backend/optimizer/graph_kernel/clean_all_in_once.h" | #include "backend/optimizer/graph_kernel/clean_all_in_once.h" | ||||
| @@ -201,6 +202,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_ | |||||
| // will be exposed, use GetitemTuple Pass to delete them. | // will be exposed, use GetitemTuple Pass to delete them. | ||||
| pm->AddPass(std::make_shared<opt::GetitemTuple>()); | pm->AddPass(std::make_shared<opt::GetitemTuple>()); | ||||
| pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>()); | pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>()); | ||||
| pm->AddPass(std::make_shared<opt::StitchAtomicCleanInsertter>()); | |||||
| pm->AddPass(std::make_shared<opt::DependFormater>()); // Prevent fake loop in parallel fusion. | pm->AddPass(std::make_shared<opt::DependFormater>()); // Prevent fake loop in parallel fusion. | ||||
| pm->AddPass(std::make_shared<opt::ParallelOpFusion>(kGPUDevice, opt::ParallelConfig(7))); | pm->AddPass(std::make_shared<opt::ParallelOpFusion>(kGPUDevice, opt::ParallelConfig(7))); | ||||
| pm->AddPass(std::make_shared<opt::BindValueToGraph>()); | pm->AddPass(std::make_shared<opt::BindValueToGraph>()); | ||||
| @@ -384,6 +384,7 @@ constexpr auto kAttrIsGrad = "is_grad"; | |||||
| constexpr auto kAttrRecompute = "recompute"; | constexpr auto kAttrRecompute = "recompute"; | ||||
| constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; | constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; | ||||
| constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; | constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; | ||||
| constexpr auto kAttrStitch = "stitch"; | |||||
| // attr value | // attr value | ||||
| constexpr auto kValueTargetSwitch = "target_switch"; | constexpr auto kValueTargetSwitch = "target_switch"; | ||||
| @@ -133,7 +133,7 @@ inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range"); | |||||
| // NN | // NN | ||||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | ||||
| inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("SoftMax"); | |||||
| inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax"); | |||||
| inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | ||||
| inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | ||||
| inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | ||||