| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """Cost model splitter""" | |||
| import os | |||
| from functools import reduce | |||
| from .model import PrimLib, Graph, Tensor | |||
| @@ -23,12 +24,19 @@ class GraphSplitByPattern: | |||
| MODE_BASIC = 1 | |||
| MODE_COMPOSITE = 2 | |||
| class StitchInfo: | |||
| """StitchInfo""" | |||
| def __init__(self): | |||
| self.stitch_ops = set() | |||
| self.stitch_atomic_ops = set() | |||
| def __init__(self, init_op, is_output): | |||
| self.pattern = PrimLib.iter_type(init_op) | |||
| self.ops = [init_op] | |||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| self.mode = None | |||
| self.stitch_info = self.StitchInfo() | |||
| self.is_output = is_output | |||
| self.output_excluded = set() | |||
| if self.pattern == PrimLib.REDUCE: | |||
| @@ -69,6 +77,12 @@ class GraphSplitByPattern: | |||
| for input_area, r in self.in_relations.items(): | |||
| input_area.out_relations[self] = r | |||
| def update_stitch_info(self, stitch_info): | |||
| if stitch_info.stitch_ops: | |||
| self.stitch_info.stitch_ops.update(stitch_info.stitch_ops) | |||
| if stitch_info.stitch_atomic_ops: | |||
| self.stitch_info.stitch_atomic_ops.update(stitch_info.stitch_atomic_ops) | |||
| def fuse(self, area): | |||
| """Fuse `area` to `self`""" | |||
| def _update_relation(relations, a, r): | |||
| @@ -107,6 +121,7 @@ class GraphSplitByPattern: | |||
| self.is_output = True | |||
| if area.output_excluded: | |||
| self.output_excluded.update(area.output_excluded) | |||
| self.update_stitch_info(area.stitch_info) | |||
| def check_circle(self, to): | |||
| """Check circle. It returns false if circle exists""" | |||
| @@ -181,10 +196,25 @@ class GraphSplitByPattern: | |||
| graphmodes = [] | |||
| for i, area in enumerate(self.areas): | |||
| area.ops.sort(key=lambda op: ids[op]) | |||
| subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops)) | |||
| subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info)) | |||
| graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite") | |||
| return subgraphs, graphmodes | |||
| def dump_subgraphs(self, subgraphs): | |||
| """Dump subgraphs""" | |||
| if os.environ.get("ENABLE_SUBGRAPHS", "off") == "on": | |||
| subgraphs_str = "subgraphs:\nlen: " + str(len(subgraphs)) + "\n" | |||
| for i, sub in enumerate(subgraphs): | |||
| subgraphs_str += str("============") + str(i) + "\n" | |||
| subgraphs_str += str(sub) | |||
| dirname = 'subgraphs' | |||
| if not os.path.exists(dirname): | |||
| os.makedirs(dirname) | |||
| graphname = self.graph.name | |||
| filename = dirname + '/' + graphname + '.log' | |||
| with open(filename, 'w') as f: | |||
| f.write(subgraphs_str) | |||
| def split(self): | |||
| """Split graph by pattern""" | |||
| self.do_split() | |||
| @@ -192,6 +222,7 @@ class GraphSplitByPattern: | |||
| # Note: after this function, the input-output relation is no longer maintained. | |||
| self.split_output_reshapes() | |||
| subgraphs, graphmodes = self.to_subgraphs() | |||
| self.dump_subgraphs(subgraphs) | |||
| return subgraphs, graphmodes | |||
| def split_output_reshapes(self): | |||
| @@ -362,15 +393,25 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| return reduce_size >= 1024 | |||
| return True | |||
| def _reduce_nums(ops): | |||
| count = 0 | |||
| for op in ops: | |||
| if op.prim.startswith('Reduce'): | |||
| count += 1 | |||
| return count | |||
| def _reduce_output(dom): | |||
| if dom.pattern != PrimLib.REDUCE: | |||
| return None | |||
| if _reduce_nums(dom.ops) > 1: | |||
| return None | |||
| if _is_atomic_add_available(dom): | |||
| return None | |||
| is_all_reduce = _tensor_size(dom.ops[0].output) == 1 | |||
| # exclude large-size all-reduce | |||
| if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | |||
| @@ -378,6 +419,24 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| fused.append(a) | |||
| return fused, False | |||
| def _reduce_stitch(dom): | |||
| if dom.pattern != PrimLib.REDUCE: | |||
| return None | |||
| if _tensor_size(dom.ops[0].output) == 1: | |||
| return None | |||
| if _tensor_size(dom.ops[0].inputs[0]) < 1024 * 12: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||
| if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_circle(a): | |||
| if _reduce_nums(a.ops) < 2: | |||
| # only stitch softmax-like patterns: more than 4 ops with a 4-D input | |||
| if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4: | |||
| dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) | |||
| fused.append(a) | |||
| return fused, False | |||
| def _transpose(dom): | |||
| if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose": | |||
| return None | |||
| @@ -398,6 +457,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| changed = self.fuse(_broadcast_width) or changed | |||
| if use_poly_reduce: | |||
| changed = self.fuse(_reduce_output) or changed | |||
| changed = self.fuse(_reduce_stitch) or changed | |||
| self.fuse(_transpose) | |||
| class GraphSplitAscend(GraphSplitByPattern): | |||
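A note for reviewers on the dump added above: `dump_subgraphs` only writes anything when the `ENABLE_SUBGRAPHS` environment variable equals `on`, and it writes each split graph to `subgraphs/<graph_name>.log`. A minimal sketch of enabling it (variable name and output path taken from the code above; the call site is illustrative):

```python
import os

# Set before the graph-kernel split runs, e.g. at the top of the training script.
# With this on, GraphSplitByPattern.split() dumps every split subgraph to
# ./subgraphs/<graph_name>.log via the new dump_subgraphs() hook.
os.environ["ENABLE_SUBGRAPHS"] = "on"
```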
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -310,11 +310,12 @@ class Operator: | |||
| class Graph: | |||
| """Graph""" | |||
| def __init__(self, name, ops): | |||
| def __init__(self, name, ops, stitch_info=None): | |||
| self.name = name | |||
| self.ops = ops # in topo order, can not use set | |||
| self.inputs = [] | |||
| self.outputs = [] | |||
| self.stitch_info = stitch_info | |||
| def set_processor(self, processor): | |||
| """Set processor""" | |||
| @@ -372,6 +373,12 @@ class Graph: | |||
| out_str = ', '.join([repr(t) for t in outputs]) | |||
| lines = [] | |||
| lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str)) | |||
| if self.stitch_info: | |||
| if self.stitch_info.stitch_ops: | |||
| lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops)) | |||
| if self.stitch_info.stitch_atomic_ops: | |||
| lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops)) | |||
| for op in self.ops: | |||
| lines.append(' ' + str(op)) | |||
| lines.append('}') | |||
| @@ -405,12 +412,20 @@ class Graph: | |||
| in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape, | |||
| 'tensor_name': t.name, 'format': t.data_format}]) | |||
| out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape, | |||
| 'tensor_name': op.output.name, 'format': t.data_format}] | |||
| 'tensor_name': op.output.name, 'format': op.output.data_format}] | |||
| op_desc.append({'attr': attrs, 'impl_path': '', | |||
| 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc}) | |||
| graph_desc = {'composite': True, 'composite_graph': '', 'id': 0, | |||
| 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc, | |||
| 'platform': 'AKG', 'process': self.processor} | |||
| if self.stitch_info and self.stitch_info.stitch_ops: | |||
| buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)} | |||
| if self.stitch_info.stitch_atomic_ops: | |||
| buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops) | |||
| graph_desc['buffer_stitch'] = buffer_stitch | |||
| return graph_desc | |||
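For clarity, a sketch of the `buffer_stitch` entry that the descriptor above is expected to carry when stitch info is attached (the keys `buffer_stitch`, `stitch_op`, and `stitch_atomic_op` come from the diff; the graph and tensor names are invented):

```python
# Hypothetical excerpt of graph_desc for a stitched subgraph; unrelated keys omitted.
graph_desc = {
    "op": "Fused_Softmax_split_0",            # illustrative graph name
    "platform": "AKG",
    "buffer_stitch": {
        "stitch_op": ["output_0_3"],          # from stitch_info.stitch_ops
        # "stitch_atomic_op": ["output_0_5"], # added only when stitch_atomic_ops is non-empty
    },
}
```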
| @@ -313,6 +313,14 @@ class CompositeGraph: | |||
| self.graph = builder.get()[0] | |||
| self.desc = desc | |||
| def add_stitch_info(self, subgraph, desc): | |||
| if subgraph.stitch_info and subgraph.stitch_info.stitch_ops: | |||
| buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)} | |||
| if subgraph.stitch_info.stitch_atomic_ops: | |||
| buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops) | |||
| desc['buffer_stitch'] = buffer_stitch | |||
| return desc | |||
| def dump(self, subgraph): | |||
| """Dump Graph to json""" | |||
| desc = {} | |||
| @@ -368,6 +376,8 @@ class CompositeGraph: | |||
| desc[key] = subgraph.name | |||
| else: | |||
| desc[key] = self.desc[key] | |||
| desc = self.add_stitch_info(subgraph, desc) | |||
| return desc | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -414,6 +414,35 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js | |||
| return DecodeFusedNodes(kernel_json); | |||
| } | |||
| StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) { | |||
| StitchInfo info; | |||
| if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) { | |||
| nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch]; | |||
| if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) { | |||
| std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp]; | |||
| info.stitch_ops = stitch_ops; | |||
| } | |||
| if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) { | |||
| std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp]; | |||
| info.stitch_atomic_ops = stitch_atomic_ops; | |||
| } | |||
| } | |||
| return info; | |||
| } | |||
| void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) { | |||
| std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc]; | |||
| if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return; | |||
| std::string tensor_name = output_descs[0][kJsonKeyTensorName]; | |||
| if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) { | |||
| AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node); | |||
| } | |||
| if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) != | |||
| info.stitch_atomic_ops.end()) { | |||
| AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node); | |||
| } | |||
| } | |||
| bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | |||
| const std::map<std::string, AnfNodePtr> &address_node_map, | |||
| AnfNodePtrList *res_graphs) { | |||
| @@ -425,6 +454,7 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | |||
| MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json; | |||
| return false; | |||
| } | |||
| StitchInfo info = GetStitchInfo(kernel_json); | |||
| for (const auto &op_desc : op_node_descs) { | |||
| if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) { | |||
| MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc; | |||
| @@ -436,7 +466,9 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | |||
| MS_LOG(ERROR) << "Decode failed, ptr_address not found in map."; | |||
| return false; | |||
| } | |||
| res_graphs->push_back(address_node_map.at(ptr_address)); | |||
| auto node = address_node_map.at(ptr_address)->cast<CNodePtr>(); | |||
| SetStitchAttr(op_desc, info, node); | |||
| res_graphs->push_back(node); | |||
| } | |||
| MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size(); | |||
| return true; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -26,6 +26,10 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| struct StitchInfo { | |||
| std::vector<std::string> stitch_ops; | |||
| std::vector<std::string> stitch_atomic_ops; | |||
| }; | |||
| class AkgKernelJsonDecoder { | |||
| public: | |||
| AkgKernelJsonDecoder() { nodes_map_.clear(); } | |||
| @@ -40,6 +44,8 @@ class AkgKernelJsonDecoder { | |||
| ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph); | |||
| CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor); | |||
| AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph); | |||
| StitchInfo GetStitchInfo(const nlohmann::json &kernel_json); | |||
| void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node); | |||
| std::map<std::string, AnfNodePtr> nodes_map_; | |||
| }; | |||
| } // namespace kernel | |||
| @@ -580,6 +580,31 @@ void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &process | |||
| (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json; | |||
| } | |||
| void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, | |||
| std::map<AnfNodePtr, nlohmann::json> *node_json_map, | |||
| nlohmann::json *kernel_json) { | |||
| std::vector<std::string> stitchs; | |||
| for (auto const &anf_node : anf_nodes) { | |||
| if (AnfAlgo::HasNodeAttr(kAttrStitch, anf_node->cast<CNodePtr>()) && | |||
| AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrStitch) == "common") { | |||
| auto name = GetTensorName((*node_json_map)[anf_node], kJsonKeyOutputDesc, {0, 0}); | |||
| if (std::find(stitchs.begin(), stitchs.end(), name) == stitchs.end()) { | |||
| stitchs.emplace_back(name); | |||
| } | |||
| } | |||
| } | |||
| if (!stitchs.empty()) { | |||
| std::vector<nlohmann::json> v; | |||
| for (auto &s : stitchs) { | |||
| std::vector<std::string> t; | |||
| t.emplace_back(s); | |||
| v.emplace_back(t); | |||
| } | |||
| nlohmann::json stitch_json; | |||
| stitch_json[kJsonKeyStitchOp] = v; | |||
| (*kernel_json)[kJsonKeyBufferStitch] = stitch_json; | |||
| } | |||
| } | |||
| bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, | |||
| const std::vector<AnfNodePtr> &input_list, | |||
| const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) { | |||
| @@ -637,6 +662,8 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf | |||
| (*kernel_json)[kJsonKeyComposite] = true; | |||
| (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id(); | |||
| GenStitchJson(anf_nodes, &node_json_map, kernel_json); | |||
| if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) { | |||
| MS_LOG(ERROR) << "Cal mem size failed."; | |||
| return false; | |||
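One detail worth calling out in `GenStitchJson`: each stitched tensor name is wrapped in its own single-element list (the `t` vectors collected into `v`), so the `stitch_op` value written here is a list of lists rather than the flat list produced by the Python `Graph` dump. A hedged sketch of the resulting fragment (tensor names are invented):

```python
# Expected shape of the buffer_stitch entry written by GenStitchJson; names are illustrative.
kernel_json_fragment = {
    "buffer_stitch": {
        "stitch_op": [["output_0_3"], ["output_0_7"]],  # one inner list per stitched tensor
    },
}
```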
| @@ -54,6 +54,9 @@ constexpr auto kJsonKeyParallelFusion = "parallel_fusion"; | |||
| constexpr auto kJsonKeyFusionType = "fusion_type"; | |||
| constexpr auto kJsonKeySubGraph = "sub_graph"; | |||
| constexpr auto kJsonKeyCoreNum = "core_num"; | |||
| constexpr auto kJsonKeyBufferStitch = "buffer_stitch"; | |||
| constexpr auto kJsonKeyStitchOp = "stitch_op"; | |||
| constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op"; | |||
| constexpr auto kAttrInputNames = "input_names"; | |||
| @@ -98,6 +101,8 @@ class AkgKernelJsonGenerator { | |||
| void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes, const OpAttrPtr &op_attr, | |||
| nlohmann::json *attr_json, const ValuePtr &attr_value); | |||
| bool CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json); | |||
| void GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map, | |||
| nlohmann::json *kernel_json); | |||
| bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size, std::vector<size_t> *output_size); | |||
| bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map); | |||
| void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -21,6 +21,7 @@ | |||
| #include <tuple> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| @@ -28,21 +29,24 @@ namespace mindspore { | |||
| namespace opt { | |||
| class AtomicCleanInsertter : public Pass { | |||
| public: | |||
| AtomicCleanInsertter() : Pass("atomic_clean") {} | |||
| explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {} | |||
| ~AtomicCleanInsertter() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| virtual bool Run(const FuncGraphPtr &func_graph); | |||
| private: | |||
| void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, | |||
| const FuncGraphManagerPtr &mng); | |||
| bool CanActivateAtomicAdd(const AnfNodePtr &anf_node); | |||
| protected: | |||
| virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); | |||
| virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, | |||
| const FuncGraphManagerPtr &mng); | |||
| void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); | |||
| void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, | |||
| const AnfNodePtr &user_node, int index); | |||
| void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); | |||
| CNodePtr atomic_add_node_{nullptr}; | |||
| private: | |||
| bool CanActivateAtomicAdd(const AnfNodePtr &anf_node); | |||
| void CorrectAbstract(const AnfNodePtr &composite_node); | |||
| void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); | |||
| CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type); | |||
| void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); | |||
| void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, | |||
| const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng); | |||
| std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph); | |||
| @@ -55,7 +59,6 @@ class AtomicCleanInsertter : public Pass { | |||
| bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node, | |||
| const FuncGraphManagerPtr &mng); | |||
| CNodePtr atomic_add_node_{nullptr}; | |||
| size_t reduce_real_output_index_{0}; | |||
| size_t real_output_num_{0}; | |||
| std::vector<std::pair<AnfNodePtr, AnfNodePtr>> to_process_order_; | |||
| @@ -0,0 +1,200 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h" | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <list> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <set> | |||
| #include <stack> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <vector> | |||
| #include "base/core_ops.h" | |||
| #include "ir/tensor.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) { | |||
| // Change kernel build info. | |||
| auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo(); | |||
| auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats(); | |||
| auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats(); | |||
| auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes(); | |||
| auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes(); | |||
| auto origin_processor = origin_kernel_build_info->processor(); | |||
| std::vector<std::string> &new_inputs_format = origin_inputs_format; | |||
| std::vector<TypeId> &new_inputs_type = origin_inputs_type; | |||
| std::vector<std::string> new_outputs_format; | |||
| std::vector<TypeId> new_outputs_type; | |||
| for (size_t i = 0; i < origin_outputs_format.size(); ++i) { | |||
| new_outputs_format.push_back(origin_outputs_format[i]); | |||
| new_outputs_type.push_back(origin_outputs_type[i]); | |||
| } | |||
| auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0); | |||
| new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second)); | |||
| new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second)); | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder; | |||
| new_info_builder.SetInputsFormat(new_inputs_format); | |||
| new_info_builder.SetInputsDeviceType(new_inputs_type); | |||
| new_info_builder.SetOutputsFormat(new_outputs_format); | |||
| new_info_builder.SetOutputsDeviceType(new_outputs_type); | |||
| new_info_builder.SetProcessor(origin_processor); | |||
| new_info_builder.SetKernelType(KernelType::AKG_KERNEL); | |||
| new_info_builder.SetFusionType(kernel::FusionType::OPAQUE); | |||
| auto new_selected_info = new_info_builder.Build(); | |||
| AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get()); | |||
| } | |||
| CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, | |||
| const AnfNodePtr &new_parameter) { | |||
| // Add an InplaceAssign node. | |||
| AnfNodePtr out_node = atomic_add_node_; // Use the result data itself, and set attr "fake_output" to true. | |||
| auto inplace_assign_node = | |||
| CreateCNode({NewValueNode(std::make_shared<Primitive>("InplaceAssign")), new_parameter, atomic_add_node_, out_node}, | |||
| sub_graph, {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)}); | |||
| AnfAlgo::SetNodeAttr("fake_output", MakeValue(true), inplace_assign_node); | |||
| AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_); | |||
| AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), inplace_assign_node); | |||
| return inplace_assign_node; | |||
| } | |||
| void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, | |||
| const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng) { | |||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node); | |||
| auto mng_sub = sub_graph->manager(); | |||
| if (mng_sub == nullptr) { | |||
| mng_sub = Manage(sub_graph, false); | |||
| sub_graph->set_manager(mng_sub); | |||
| } | |||
| // add input | |||
| auto inputs = composite_node->cast<CNodePtr>()->inputs(); | |||
| inputs.push_back(new_input); | |||
| composite_node->cast<CNodePtr>()->set_inputs(inputs); | |||
| // add parameter | |||
| auto parameter = sub_graph->add_parameter(); | |||
| parameter->set_abstract(new_input->abstract()); | |||
| parameter->set_kernel_info(new_input->kernel_info_ptr()); | |||
| auto inplace_assign = CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter); | |||
| // Replace the atomic ReduceSum's users with the atomic-clean output, and add a Depend after the | |||
| // InplaceAssign to prevent it from being eliminated. | |||
| std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_); | |||
| bool connected = false; | |||
| for (const auto &[user_node, index] : reduce_user_nodes) { | |||
| auto user_cnode = user_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(user_cnode); | |||
| user_cnode->set_input(index, parameter); | |||
| if (!connected) { | |||
| std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode); | |||
| if (!user_user.empty()) { | |||
| auto pair = user_user[0]; | |||
| AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second); | |||
| } | |||
| connected = true; | |||
| } | |||
| CorrectKernelBuildInfo(composite_node, new_input); | |||
| } | |||
| auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); | |||
| sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); | |||
| MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name; | |||
| } | |||
| std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node, | |||
| const CNodePtr &target) { | |||
| auto node = inner_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| auto mng_sub = sub_graph->manager(); | |||
| if (mng_sub == nullptr) { | |||
| mng_sub = Manage(sub_graph, false); | |||
| sub_graph->set_manager(mng_sub); | |||
| } | |||
| std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes; | |||
| auto users = mng_sub->node_users()[target]; | |||
| std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes), | |||
| [](const std::pair<AnfNodePtr, int> &pair) { return pair; }); | |||
| return inner_user_nodes; | |||
| } | |||
| bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) { | |||
| if (!AnfAlgo::IsGraphKernel(anf_node)) return false; | |||
| auto node = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| AnfNodePtrList kernel_nodes; | |||
| kernel::GetValidKernelNodes(sub_graph, &kernel_nodes); | |||
| for (auto &n : kernel_nodes) { | |||
| if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) && | |||
| AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) { | |||
| MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!"; | |||
| atomic_add_node_ = n->cast<CNodePtr>(); | |||
| stitch_node_ = anf_node; | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { | |||
| auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto mng = kernel_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(kernel_graph, true); | |||
| kernel_graph->set_manager(mng); | |||
| } | |||
| bool changed = false; | |||
| auto topo_nodes = TopoSort(kernel_graph->get_return()); | |||
| for (const auto &node : topo_nodes) { | |||
| // If the stitch attr exists, insert an atomic clean op according to that attr. | |||
| if (IsStitchWithAtomic(node)) { | |||
| InsertAtomicClean(kernel_graph, node, mng); | |||
| changed = true; | |||
| } | |||
| } | |||
| if (changed) { | |||
| mng->RemoveRoots(); | |||
| mng->KeepRoots({func_graph}); | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_ | |||
| #include <memory> | |||
| #include <tuple> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class StitchAtomicCleanInsertter : public AtomicCleanInsertter { | |||
| public: | |||
| StitchAtomicCleanInsertter() : AtomicCleanInsertter("stitch_atomic_clean") {} | |||
| ~StitchAtomicCleanInsertter() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); | |||
| CNodePtr CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); | |||
| std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target); | |||
| void ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, | |||
| const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng); | |||
| bool IsStitchWithAtomic(const AnfNodePtr &anf_node); | |||
| AnfNodePtr stitch_node_{nullptr}; | |||
| }; | |||
| using StitchAtomicCleanInsertterPtr = std::shared_ptr<StitchAtomicCleanInsertter>; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -717,6 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() { | |||
| prim::kPrimMinimumGrad, | |||
| prim::kPrimGkDropout, | |||
| prim::kPrimDropoutGrad, | |||
| prim::kPrimSoftMax, | |||
| #endif | |||
| }; | |||
| return expand_ops; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -42,6 +42,7 @@ | |||
| #include "backend/optimizer/gpu/add_relu_v2_fusion.h" | |||
| #include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" | |||
| #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h" | |||
| #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | |||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/clean_all_in_once.h" | |||
| @@ -201,6 +202,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_ | |||
| // will be exposed, use GetitemTuple Pass to delete them. | |||
| pm->AddPass(std::make_shared<opt::GetitemTuple>()); | |||
| pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>()); | |||
| pm->AddPass(std::make_shared<opt::StitchAtomicCleanInsertter>()); | |||
| pm->AddPass(std::make_shared<opt::DependFormater>()); // Prevent fake loop in parallel fusion. | |||
| pm->AddPass(std::make_shared<opt::ParallelOpFusion>(kGPUDevice, opt::ParallelConfig(7))); | |||
| pm->AddPass(std::make_shared<opt::BindValueToGraph>()); | |||
| @@ -384,6 +384,7 @@ constexpr auto kAttrIsGrad = "is_grad"; | |||
| constexpr auto kAttrRecompute = "recompute"; | |||
| constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; | |||
| constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; | |||
| constexpr auto kAttrStitch = "stitch"; | |||
| // attr value | |||
| constexpr auto kValueTargetSwitch = "target_switch"; | |||
| @@ -133,7 +133,7 @@ inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range"); | |||
| // NN | |||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("SoftMax"); | |||
| inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax"); | |||
| inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | |||
| inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | |||
| inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | |||