Merge pull request !7013 from DeshiChen/0928_costmodel_multioutputtags/v1.1.0
| @@ -14,138 +14,221 @@ | |||
| # =========================================================================== | |||
| """Cost model splitter""" | |||
| from .model import PrimLib, Graph | |||
| from .model import PrimLib, Graph, Tensor | |||
| class GraphSplitByPattern: | |||
| """Graph split by pattern""" | |||
| """Graph splitter""" | |||
| class Area: | |||
| """Area""" | |||
| MODE_BASIC = 1 | |||
| MODE_COMPOSITE = 2 | |||
| def __init__(self, init_op): | |||
| self.pattern = PrimLib.iter_type(init_op) | |||
| self.ops = [init_op] | |||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| self.mode = self.MODE_BASIC | |||
| def __str__(self): | |||
| return '<' + '-'.join([op.output.name for op in self.ops]) + '>' | |||
| def __repr__(self): | |||
| return str(self) | |||
| def link_input(self, area_map): | |||
| """Link inputs""" | |||
| def get_relation(op, i): | |||
| relation = PrimLib.UNKNOWN | |||
| _, elem_relation = PrimLib.input_relation(op, i) | |||
| for r in elem_relation: | |||
| if r is not None and r > relation: | |||
| relation = r | |||
| return relation | |||
| for i, t in enumerate(self.ops[0].inputs): | |||
| if t.op is not None: | |||
| area, relation = area_map[t.op], get_relation(self.ops[0], i) | |||
| self.in_relations[area] = relation | |||
| def link_output(self): | |||
| """Link outputs""" | |||
| for input_area, r in self.in_relations.items(): | |||
| input_area.out_relations[self] = r | |||
| def fuse(self, area): | |||
| """Fuse `area` to `self`""" | |||
| def _update_relation(relations, a, r): | |||
| relations[a] = max(r, relations[a]) if a in relations else r | |||
| def _update_pattern(): | |||
| self.pattern = max(self.pattern, area.pattern, self.in_relations[area]) | |||
| def _fuse_relation(self_relations, new_relations): | |||
| for a, r in new_relations.items(): | |||
| if a != self: | |||
| _update_relation(self_relations, a, r) | |||
| if area in self_relations: | |||
| self_relations.pop(area) | |||
| def _redirect_relation(rels): | |||
| """Replace `area` with `self` in relations""" | |||
| if area in rels: | |||
| r = rels.pop(area) | |||
| _update_relation(rels, self, r) | |||
| self.ops.extend(area.ops) | |||
| _update_pattern() | |||
| _fuse_relation(self.in_relations, area.in_relations) | |||
| _fuse_relation(self.out_relations, area.out_relations) | |||
| for a, _ in area.in_relations.items(): | |||
| _redirect_relation(a.out_relations) | |||
| for a, _ in area.out_relations.items(): | |||
| _redirect_relation(a.in_relations) | |||
| self.mode = self.MODE_COMPOSITE | |||
| def check_circle(self, to): | |||
| """Check circle. It returns false if circle exists""" | |||
| def _reached(area, to): | |||
| for out, _ in area.out_relations.items(): | |||
| if out == to or _reached(out, to): | |||
| return True | |||
| return False | |||
| for out, _ in self.out_relations.items(): | |||
| if out != to and _reached(out, to): | |||
| return False | |||
| return True | |||
| BORADCAST_FUSE_DEPTH = 3 | |||
| REDUCE_FUSE_DEPTH = 3 | |||
| def __init__(self, graph): | |||
| self.graph = graph | |||
| self.groups = [] | |||
| self.op_group = {} | |||
| for op in self.graph.ops: | |||
| g = [op] | |||
| self.groups.append(g) | |||
| self.op_group[op] = g | |||
| self.ids = {} | |||
| for i, op in enumerate(graph.ops): | |||
| self.ids[op] = i | |||
| self.doms = self.post_dom(graph.ops) | |||
| _, outputs = graph.deduce_parameters() | |||
| self.outputs = set(outputs) | |||
| def post_dom(self, ops): | |||
| """Post dom""" | |||
| doms, i_doms = {}, {} | |||
| for i in range(len(ops) - 1, -1, -1): | |||
| op = ops[i] | |||
| doms[op] = {op} | |||
| i_dom = None | |||
| if op.output.to_ops: | |||
| suc_dom = set(doms[op.output.to_ops[0]]) | |||
| for to in op.output.to_ops[1:]: | |||
| suc_dom.intersection_update(doms[to]) | |||
| doms[op].update(suc_dom) | |||
| for dom in suc_dom: | |||
| if i_dom is None or self.ids[dom] < self.ids[i_dom]: | |||
| i_dom = dom | |||
| i_doms[op] = i_dom | |||
| return i_doms | |||
| def get_pattern(self, op, i): | |||
| """Get pattern""" | |||
| pattern = PrimLib.UNKNOWN | |||
| _, elem_relation = PrimLib.input_relation(op, i) | |||
| for pat in elem_relation: | |||
| if pat and pat > pattern: | |||
| pattern = pat | |||
| return pattern | |||
| def fuse(self, check_fun): | |||
| """Fuse ops""" | |||
| def _get_path(op, dom): | |||
| path_ops, visited = [], set() | |||
| def _get_path_depth(p): | |||
| visited.add(p) | |||
| if self.op_group[p][0] == p: | |||
| path_ops.append(p) | |||
| for to in p.output.to_ops: | |||
| if to != dom and to not in visited: | |||
| _get_path_depth(to) | |||
| _get_path_depth(op) | |||
| return path_ops | |||
| changed = True | |||
| while changed: | |||
| for group in self.groups: | |||
| op = group[0] | |||
| dom = self.doms[op] | |||
| if dom is None or op.output in self.outputs: | |||
| continue | |||
| ops = _get_path(op, dom) | |||
| if check_fun(op, dom, ops): | |||
| dom_group = self.op_group[dom] | |||
| fused = [] | |||
| for fop in ops: | |||
| f_group = self.op_group[fop] | |||
| for p in f_group: | |||
| self.op_group[p] = dom_group | |||
| fused.append(f_group) | |||
| dom_group += f_group | |||
| for g in fused: | |||
| self.groups.remove(g) | |||
| self.areas = [] | |||
| area_map = {} | |||
| for op in graph.ops: | |||
| a = self.Area(op) | |||
| self.areas.append(a) | |||
| area_map[op] = a | |||
| for a in self.areas: | |||
| a.link_input(area_map) | |||
| for a in self.areas: | |||
| a.link_output() | |||
| def fuse(self, selector): | |||
| """Fuse areas""" | |||
| changed = False | |||
| while True: | |||
| for dominant in self.areas: | |||
| fuse_areas = selector(dominant) | |||
| if fuse_areas: | |||
| for area in fuse_areas: | |||
| changed = True | |||
| dominant.fuse(area) | |||
| self.areas.remove(area) | |||
| break | |||
| else: | |||
| changed = False | |||
| return changed | |||
| def to_subgraphs(self): | |||
| """Transform op groups to subgraphs""" | |||
| ids = {} | |||
| for i, op in enumerate(self.graph.ops): | |||
| ids[op] = i | |||
| subgraphs = [] | |||
| for i, group in enumerate(self.groups): | |||
| group.sort(key=lambda op: self.ids[op]) | |||
| subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), group)) | |||
| return subgraphs | |||
| graphmodes = [] | |||
| for i, area in enumerate(self.areas): | |||
| area.ops.sort(key=lambda op: ids[op]) | |||
| subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops)) | |||
| graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite") | |||
| return subgraphs, graphmodes | |||
| def split(self): | |||
| """Split graph""" | |||
| def _buddy(op, dom, path_ops): | |||
| """Fuse buddy together""" | |||
| group = self.op_group[op] | |||
| for p in group: | |||
| # p is buddy | |||
| if p.output.buddy is not None and p.output.buddy.members[0].op not in group: | |||
| """Split graph by pattern""" | |||
| def _elemwise_depth(dom): | |||
| if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1: | |||
| return None | |||
| a, r = list(dom.in_relations.items())[0] | |||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE: | |||
| return None | |||
| return [a] | |||
| def _elemwise_width(dom): | |||
| if dom.pattern > PrimLib.BROADCAST: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom): | |||
| fused.append(a) | |||
| return fused | |||
| def _broadcast_depth(dom): | |||
| if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1: | |||
| return None | |||
| a, r = list(dom.in_relations.items())[0] | |||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \ | |||
| r != PrimLib.BROADCAST or len(a.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| return None | |||
| return [a] | |||
| def _broadcast_width(dom): | |||
| if dom.pattern > PrimLib.BROADCAST: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \ | |||
| a.check_circle(dom) and len(a.ops) <= self.BORADCAST_FUSE_DEPTH: | |||
| fused.append(a) | |||
| return fused | |||
| def _check_reduce_exclude(dom): | |||
| # exclude large all-reduce | |||
| if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \ | |||
| dom.ops[0].inputs[0].get_size() > 10000: | |||
| return True | |||
| # exclude multi output | |||
| for a in dom.in_relations.keys(): | |||
| if len(a.out_relations) > 1: | |||
| return True | |||
| if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]): | |||
| return True | |||
| # p's output is buddy | |||
| for to in p.output.to_ops: | |||
| if to.output.buddy is not None and to not in group: | |||
| return True | |||
| return False | |||
| def _injective(pattern, limit): | |||
| def _checker(op, dom, path_ops): | |||
| for p in op.output.to_ops: | |||
| if p not in self.op_group[dom]: | |||
| return False | |||
| if PrimLib.iter_type(op) in (PrimLib.ELEMWISE, PrimLib.BROADCAST): | |||
| for i, t in enumerate(dom.inputs): | |||
| if t == op.output: | |||
| return self.get_pattern(dom, i) == pattern and len(self.op_group[op]) < limit | |||
| return False | |||
| return _checker | |||
| def _reduce_depth(dom): | |||
| if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: | |||
| return None | |||
| if _check_reduce_exclude(dom): | |||
| return None | |||
| a, r = list(dom.in_relations.items())[0] | |||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \ | |||
| r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH: | |||
| return None | |||
| return [a] | |||
| def _diamond(op, dom, path_ops): | |||
| if PrimLib.iter_type(op) not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ | |||
| PrimLib.iter_type(dom) in (PrimLib.UNKNOWN, PrimLib.TRANSFORM): | |||
| return False | |||
| return len(path_ops) == 1 and op.output not in dom.inputs | |||
| self.fuse(_buddy) | |||
| self.fuse(_injective(PrimLib.ELEMWISE, 100)) | |||
| self.fuse(_injective(PrimLib.BROADCAST, 6)) | |||
| self.fuse(_injective(PrimLib.REDUCE, 6)) | |||
| self.fuse(_diamond) | |||
| return self.to_subgraphs() | |||
| def _reduce_width(dom): | |||
| if dom.pattern != PrimLib.REDUCE: | |||
| return None | |||
| if _check_reduce_exclude(dom): | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \ | |||
| a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH: | |||
| fused.append(a) | |||
| return fused | |||
| changed = True | |||
| while changed: | |||
| changed = self.fuse(_elemwise_depth) | |||
| changed = self.fuse(_elemwise_width) or changed | |||
| changed = self.fuse(_broadcast_depth) or changed | |||
| changed = self.fuse(_broadcast_width) or changed | |||
| changed = self.fuse(_reduce_depth) or changed | |||
| changed = self.fuse(_reduce_width) or changed | |||
| subgraphs, graphmodes = self.to_subgraphs() | |||
| return subgraphs, graphmodes | |||
| def split(graph): | |||
| """Split graph""" | |||
| return GraphSplitByPattern(graph).split() | |||
| @@ -196,8 +196,7 @@ class CompositeGraph: | |||
| shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT) | |||
| cur_fusion = None | |||
| for op in desc['op_desc']: | |||
| inputs = [self.tensors[d[0]['tensor_name']] | |||
| for d in op['input_desc'] if 'value' not in d[0]] | |||
| inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d] | |||
| out_desc = op['output_desc'] | |||
| name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[ | |||
| 0]['shape'], out_desc[0]['data_type'], out_desc[0]['format'] | |||
| @@ -263,7 +262,7 @@ class CompositeGraph: | |||
| self.tensors[y], True) | |||
| inplace_desc = copy.deepcopy(d) | |||
| inplace_desc['attr'] = {'name': 'fake_output', 'value': fake} | |||
| z_desc, out_desc = inplace_desc['input_desc'][2][0].inplace_desc['output_desc'][0] | |||
| z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0] | |||
| z_desc['shape'] = z.shape | |||
| z_desc['data_type'] = z.dtype | |||
| z_desc['tensor_name'] = z.name | |||
| @@ -26,10 +26,12 @@ def split_with_json(json_str: str): | |||
| try: | |||
| graph_desc = json.loads(json_str) | |||
| comp = model.load_composite(graph_desc) | |||
| graph_split = model.split(comp.graph) | |||
| graph_split, graph_mode = model.split(comp.graph) | |||
| is_multi_graph = len(graph_split) > 1 | |||
| graph_list = list(map(comp.dump, graph_split)) | |||
| result = {"multi_graph": is_multi_graph, "graph_desc": graph_list} | |||
| result = {"multi_graph": is_multi_graph, | |||
| "graph_desc": graph_list, | |||
| "graph_mode": graph_mode} | |||
| return json.dumps(result) | |||
| except jd.JSONDecodeError: | |||
| logger.error(traceback.format_exc()) | |||
| @@ -1,53 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """test split""" | |||
| import model | |||
| def graph_1(): | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a = gb.tensor([1024, 16], "float32", name="a") | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("Abs", b, 'c') | |||
| d = gb.emit("Abs", c, 'd') | |||
| gb.emit("TensorAdd", [b, d], "e") | |||
| return gb.get()[0] | |||
| def graph_2(): | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a = gb.tensor([1024, 16], "float32", name="a") | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("Abs", b, 'c') | |||
| d = gb.emit("ReduceSum", c, 'd', attrs={'reduce_axis': (1,)}) | |||
| gb.emit("Sqrt", d, 'e') | |||
| return gb.get()[0] | |||
| def test_split_by_pattern(): | |||
| def _test(graph): | |||
| print("***************** main graph ***************") | |||
| print(graph) | |||
| subgraphs = model.split(graph) | |||
| for i, g in enumerate(subgraphs): | |||
| print('------------- subgraph {} --------------'.format(i)) | |||
| print(g) | |||
| _test(graph_2()) | |||
| if __name__ == '__main__': | |||
| test_split_by_pattern() | |||
| @@ -485,7 +485,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf | |||
| (*kernel_json)[kJsonKeyPlatform] = "AKG"; | |||
| (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]); | |||
| (*kernel_json)[kJsonKeyComposite] = true; | |||
| (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString(); | |||
| (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id(); | |||
| if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) { | |||
| MS_LOG(ERROR) << "Cal mem size failed."; | |||
| @@ -37,22 +37,17 @@ namespace opt { | |||
| namespace { | |||
| bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) { | |||
| #if ENABLE_D | |||
| std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, | |||
| std::vector<PrimitivePtr> fusible_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, | |||
| prim::kPrimExpandDims}; | |||
| if (!is_before_kernel_select) { | |||
| fusable_basic_ops.push_back(prim::kPrimCast); | |||
| fusible_basic_ops.push_back(prim::kPrimCast); | |||
| } | |||
| #elif ENABLE_GPU | |||
| std::vector<PrimitivePtr> fusable_basic_ops = { | |||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, | |||
| prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | |||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, | |||
| prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, | |||
| prim::kPrimGreater, prim::kPrimAssign}; | |||
| std::vector<PrimitivePtr> fusible_basic_ops = GetFusibleOpList(); | |||
| #else | |||
| std::vector<PrimitivePtr> fusable_basic_ops; | |||
| std::vector<PrimitivePtr> fusible_basic_ops; | |||
| #endif | |||
| return std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), | |||
| return std::any_of(fusible_basic_ops.begin(), fusible_basic_ops.end(), | |||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); | |||
| } | |||
| @@ -49,12 +49,7 @@ bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) { | |||
| basic_ops.push_back(prim::kPrimCast); | |||
| } | |||
| #elif ENABLE_GPU | |||
| std::vector<PrimitivePtr> basic_ops = { | |||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, | |||
| prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | |||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, | |||
| prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, | |||
| prim::kPrimGreater, prim::kPrimAssign}; | |||
| std::vector<PrimitivePtr> basic_ops = GetFusibleOpList(); | |||
| #else | |||
| std::vector<PrimitivePtr> basic_ops; | |||
| #endif | |||
| @@ -26,8 +26,8 @@ | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/func_graph.h" | |||
| #include "backend/optimizer/pass/const_input_to_attr_registry.h" | |||
| #ifdef ENABLE_D | |||
| #include "backend/kernel_compiler/tbe/tbe_kernel_build.h" | |||
| #if ENABLE_GPU | |||
| #include "runtime/device/gpu/kernel_info_setter.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -612,36 +612,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo | |||
| return new_fg; | |||
| } | |||
| bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map, | |||
| std::vector<AnfNodePtrList> *res_graphs) { | |||
| MS_EXCEPTION_IF_NULL(res_graphs); | |||
| auto kernel_json = nlohmann::json::parse(json_desc); | |||
| if (kernel_json.find(kJsonKeyMultiGraph) == kernel_json.end() || kernel_json[kJsonKeyMultiGraph].is_null()) { | |||
| // not multi graphs. | |||
| MS_LOG(ERROR) << "Input json is not multi graph, " << json_desc; | |||
| return false; | |||
| } | |||
| kernel::AkgKernelJsonDecoder akg_kernel_json_decoder; | |||
| std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc]; | |||
| if (graph_descs.empty()) { | |||
| MS_LOG(ERROR) << "No sub graph found, " << json_desc; | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < graph_descs.size(); ++i) { | |||
| const auto &graph_desc = graph_descs[i]; | |||
| AnfNodePtrList res_graph; | |||
| if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) { | |||
| MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc; | |||
| return false; | |||
| } | |||
| res_graphs->push_back(res_graph); | |||
| } | |||
| return true; | |||
| } | |||
| std::unordered_set<PrimitivePtr> GetExpandOps() { | |||
| std::unordered_set<PrimitivePtr> expand_ops = { | |||
| prim::kPrimSquare, | |||
| @@ -664,5 +634,23 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p | |||
| } | |||
| return name.str(); | |||
| } | |||
| std::vector<PrimitivePtr> GetFusibleOpList() { | |||
| std::vector<PrimitivePtr> fusible_basic_ops = { | |||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, | |||
| prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | |||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, | |||
| prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, | |||
| prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum}; | |||
| return fusible_basic_ops; | |||
| } | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| #if ENABLE_GPU | |||
| device::gpu::SetKernelInfo(cnode, kernel_type); | |||
| #endif | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -35,6 +35,7 @@ constexpr auto kGraphKernelSplitFunc = "split_with_json"; | |||
| constexpr auto kGetGraphKernelOpExpander = "get_op_expander"; | |||
| constexpr auto kJsonKeyMultiGraph = "multi_graph"; | |||
| constexpr auto kJsonKeyGraphDesc = "graph_desc"; | |||
| constexpr auto kJsonKeyGraphMode = "graph_mode"; | |||
| void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | |||
| const AnfNodePtrList &outputs, kernel::Processor processor); | |||
| @@ -50,10 +51,10 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n | |||
| std::map<std::string, AnfNodePtr> *address_node_map = nullptr); | |||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); | |||
| FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs); | |||
| bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map, | |||
| std::vector<AnfNodePtrList> *res_graphs); | |||
| std::unordered_set<PrimitivePtr> GetExpandOps(); | |||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); | |||
| std::vector<PrimitivePtr> GetFusibleOpList(); | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_ | |||
| @@ -26,6 +26,7 @@ | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| @@ -203,7 +204,7 @@ class AreaGraph { | |||
| } | |||
| SortCNodes(main_cnodes); | |||
| cnode_group_id->swap(topo_order_); // The topo_order is not used anymore. | |||
| *cnode_group_id = std::move(topo_order_); // The topo_order is not used anymore. | |||
| return; | |||
| } | |||
| @@ -291,7 +292,7 @@ class AreaGraph { | |||
| std::vector<CNodePtr> main_cnodes_sorted; | |||
| std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted), | |||
| [main_cnodes](int index) { return main_cnodes->at(index); }); | |||
| main_cnodes->swap(main_cnodes_sorted); | |||
| *main_cnodes = std::move(main_cnodes_sorted); | |||
| } | |||
| // Areas in this subgraph | |||
| @@ -415,6 +416,9 @@ class Splitter { | |||
| cnode->set_input(i, iter->second); | |||
| } | |||
| } | |||
| if (AnfAlgo::IsRealKernel(node)) { | |||
| ResetKernelInfo(node); | |||
| } | |||
| } | |||
| } | |||
| return output; | |||
| @@ -445,7 +449,7 @@ class Splitter { | |||
| tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]); | |||
| } | |||
| } | |||
| new_subgraph_cnodes_.swap(tmp_subgraph_cnodes); | |||
| new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes); | |||
| TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -580,15 +584,38 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { | |||
| return false; | |||
| } | |||
| // recover json to anf-ir. | |||
| split_plan_.clear(); | |||
| if (!JsonDescToAnf(split_graphs_str, address_node_map, &split_plan_)) { | |||
| MS_LOG(ERROR) << "Failed to decode split graphs."; | |||
| if (!DecodeJson(split_graphs_str, address_node_map)) { | |||
| MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| virtual bool DecodeJson(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map) { | |||
| auto kernel_json = nlohmann::json::parse(json_desc); | |||
| kernel::AkgKernelJsonDecoder akg_kernel_json_decoder; | |||
| std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc]; | |||
| std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode]; | |||
| if (graph_modes.size() != graph_descs.size()) { | |||
| MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatch graph_desc " << graph_descs.size(); | |||
| return false; | |||
| } | |||
| // recover json to anfnode. | |||
| split_plan_.clear(); | |||
| for (const auto &graph_desc : graph_descs) { | |||
| AnfNodePtrList res_graph; | |||
| if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) { | |||
| MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc; | |||
| return false; | |||
| } | |||
| split_plan_.push_back(std::move(res_graph)); | |||
| } | |||
| // The info should be returned from costmodel. | |||
| need_inline_.assign(split_plan_.size(), 0); | |||
| // ops to be inlined. | |||
| need_inline_.clear(); | |||
| std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_), | |||
| [](const std::string &mode) { return mode == "basic" ? 1 : 0; }); | |||
| return true; | |||
| } | |||
| @@ -13,5 +13,5 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| PYTHONPATH="$(pwd)/..:${PYTHONPATH}" | |||
| PYTHONPATH="$(pwd)/../../../../mindspore/_extends/graph_kernel:${PYTHONPATH}" | |||
| export PYTHONPATH | |||
| @@ -0,0 +1,436 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """Test split""" | |||
| import model | |||
| from model import model as estimate | |||
| from model import graph_split as split | |||
| def get_nodes(sp, ops): | |||
| """Get nodes""" | |||
| if isinstance(ops[0], str): | |||
| new_ops = [] | |||
| for t in ops: | |||
| for op in sp.graph.ops: | |||
| if op.output.name == t: | |||
| new_ops.append(op) | |||
| break | |||
| else: | |||
| print("ERROR: not found op: ", t) | |||
| ops = new_ops | |||
| return [sp.nodes[sp.graph.ops.index(op)] for op in ops] | |||
| def first_connected(sp, space): | |||
| for cand in space: | |||
| nodes = [sp.nodes[i] for i in cand[0]] | |||
| graphs = sp.resolve_connnected_graphs(nodes) | |||
| if len(graphs) != 1: | |||
| print("connect check faied: ", nodes) | |||
| return False | |||
| return True | |||
| def split_format(sp, cand): | |||
| names = [] | |||
| for ids in cand: | |||
| ops = [] | |||
| for i in ids: | |||
| ops.append(sp.graph.ops[i].output.name) | |||
| names.append(','.join(ops)) | |||
| return '|'.join(names) | |||
| def graph_1(): | |||
| ''' ring, no succ_dep, no prev ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a = gb.tensor([10240, 16], "float32", name="a") | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("Abs", b, 'c') | |||
| d = gb.emit("Abs", c, 'd') | |||
| gb.emit('TensorAdd', [b, d], 'e') | |||
| return gb.get()[0] | |||
| def graph_2(): | |||
| ''' ring, succ_dep, no prev ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([10240, 16], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("Abs", a, 'c') | |||
| d = gb.emit("Abs", b, 'd') | |||
| e = gb.emit('TensorAdd', [c, d], 'e') | |||
| gb.emit("Abs", e, 'f') | |||
| return gb.get()[0] | |||
| def graph_3(): | |||
| ''' no ring, 1 sibling node ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([10240, 16], "float32", name="a0") | |||
| a1 = gb.tensor([10240, 16], "float32", name="a1") | |||
| b = gb.emit("Abs", a0, 'b') | |||
| c = gb.emit("Abs", a1, 'c') | |||
| d = gb.emit("Abs", b, 'd') | |||
| e = gb.emit('TensorAdd', [c, d], 'e') | |||
| gb.emit("Abs", e, 'f') | |||
| return gb.get()[0] | |||
| def graph_4(): | |||
| ''' no ring, 2 sibling nodes in 1 step ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([10240, 16], "float32", name="a0") | |||
| a1 = gb.tensor([10240, 16], "float32", name="a1") | |||
| b = gb.emit("Abs", a0, 'b') | |||
| c = gb.emit("Abs", b, 'c') | |||
| d = gb.emit("Abs", a1, 'd') | |||
| e = gb.emit("Abs", d, 'e') | |||
| f = gb.emit('TensorAdd', [c, e], 'f') | |||
| gb.emit('Abs', f, 'g') | |||
| h = gb.emit("Abs", d, 'h') | |||
| i = gb.emit('TensorAdd', [c, h], 'i') | |||
| gb.emit("Abs", i, 'j') | |||
| return gb.get()[0] | |||
| def graph_5(): | |||
| ''' no ring, 2 sibling step ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main") as g: | |||
| a0 = gb.tensor([10240, 16], "float32", name="a0") | |||
| a1 = gb.tensor([10240, 16], "float32", name="a1") | |||
| a2 = gb.tensor([10240, 16], "float32", name="a2") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a1, 'b') | |||
| c = gb.emit("Abs", b, 'c') | |||
| d = gb.emit('TensorAdd', [a, c], 'd') | |||
| gb.emit("Abs", d, 'e') | |||
| f = gb.emit("Abs", a2, 'f') | |||
| g = gb.emit('TensorAdd', [c, f], 'g') | |||
| gb.emit("Abs", g, 'h') | |||
| return gb.get()[0] | |||
| def graph_6(): | |||
| ''' no ring, tree down ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([10240, 16], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| gb.emit("Abs", b, 'd') | |||
| gb.emit("Abs", b, 'e') | |||
| c = gb.emit("Abs", a, 'c') | |||
| gb.emit("Abs", c, 'f') | |||
| gb.emit("Abs", c, 'g') | |||
| return gb.get()[0] | |||
| def graph_pat_1(): | |||
| ''' split by reduce ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)}) | |||
| d = gb.emit("Sqrt", c, 'd') | |||
| gb.emit("Sqrt", d, 'f') | |||
| return gb.get()[0] | |||
| def graph_pat_2(): | |||
| ''' multi output ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)}) | |||
| gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)}) | |||
| return gb.get()[0] | |||
| def graph_pat_3(): | |||
| ''' two reduce ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)}) | |||
| d = gb.emit("Abs", c, 'd') | |||
| gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)}) | |||
| return gb.get()[0] | |||
| def graph_pat_4(): | |||
| ''' elewise + broadcast ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1, 1024], "float32", name="a0") | |||
| a2 = gb.tensor([1014, 1024], "float32", name="a2") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("Abs", b, 'c') | |||
| d = gb.emit("Abs", c, 'd') | |||
| e = gb.emit("Abs", d, 'e') | |||
| f = gb.emit("Abs", e, 'f') | |||
| g0 = gb.emit("Abs", a2, 'g0') | |||
| # g0 = gb.emit("Abs", g0, 'g0') | |||
| # g0 = gb.emit("Abs", g0, 'g0') | |||
| # g0 = gb.emit("Abs", g0, 'g0') | |||
| # g0 = gb.emit("Abs", g0, 'g0') | |||
| # g0 = gb.emit("Abs", g0, 'g0') | |||
| # g0 = gb.emit("Abs", g0, 'g0') | |||
| g0 = gb.emit("Abs", g0, 'g0') | |||
| g1 = gb.emit('TensorAdd', [f, g0], 'g1') | |||
| g2 = gb.emit("Abs", g1, 'g2') | |||
| g3 = gb.emit("Abs", g2, 'g3') | |||
| g4 = gb.emit("Abs", g3, 'g4') | |||
| gb.emit("Abs", g4, 'g5') | |||
| return gb.get()[0] | |||
| def graph_pat_5(): | |||
| ''' reduce + reshape ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)}) | |||
| d = gb.emit("Abs", c, 'd') | |||
| e = gb.tensor([512, 2048], "float32", name="e") | |||
| gb.op("Reshape", e, [d]) | |||
| return gb.get()[0] | |||
| def graph_pat_6(): | |||
| ''' dimond ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("Abs", a, 'c') | |||
| gb.emit("TensorAdd", [b, c], 'd') | |||
| gb.emit("Abs", c, 'f') # broke dimond | |||
| return gb.get()[0] | |||
| def graph_pat_7(): | |||
| ''' buddy of control op ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a1 = gb.tensor([1024, 1024], "float32", name="a1") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a1, 'b') | |||
| c = gb.emit("make_tuple", [a, b], 'c') | |||
| d = gb.tensor([1024, 1024], "float32", name="d") | |||
| gb.op("AddN", d, [c]) | |||
| gb.emit("Abs", d, 'f') | |||
| graph = gb.get()[0] | |||
| estimate.AddControlBuddy().visit_graph(graph) | |||
| return graph | |||
| def graph_pat_8(): | |||
| ''' reduce + reshape ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| #c = gb.emit("Abs", b, 'b') | |||
| c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)}) | |||
| gb.emit("TensorAdd", [b, c], 'd') | |||
| return gb.get()[0] | |||
| def graph_pat_9(): | |||
| ''' scalar ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a1 = gb.tensor([1], "float32", name="a1") | |||
| a = gb.emit("Maximum", a1, 'a') | |||
| b = gb.emit("Mul", [a, a1], 'b') | |||
| gb.emit('Mul', [b, a0], 'c') | |||
| return gb.get()[0] | |||
| def graph_mo_1(): | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main"): | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| gb.emit("Abs", a, 'b') | |||
| gb.emit("Abs", a, 'c') | |||
| return gb.get()[0] | |||
| def graph_mo_2(): | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main") as g: | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("Abs", b, 'c') | |||
| g.set_output(b, c) | |||
| return gb.get()[0] | |||
| def graph_mo_3(): | |||
| ''' two reduce ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main") as g: | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)}) | |||
| g.set_output(b, c) | |||
| return gb.get()[0] | |||
| def graph_mo_4(): | |||
| ''' two reduce ''' | |||
| gb = model.GraphBuilder() | |||
| with gb.graph_scope("main") as g: | |||
| a0 = gb.tensor([1024, 1024], "float32", name="a0") | |||
| a = gb.emit("Abs", a0, 'a') | |||
| b = gb.emit("Abs", a, 'b') | |||
| c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)}) | |||
| g.set_output(b, c) | |||
| return gb.get()[0] | |||
| def test_binary_split(): | |||
| """Test binary split""" | |||
| def _test(graph, expected_space_size): | |||
| print("********* test on graph : {} *************".format(graph.name)) | |||
| sp = split.GraphSpliter(graph) | |||
| nodes = get_nodes(sp, graph.ops) | |||
| space = sp.binary_split(nodes) | |||
| for i, s in enumerate(space): | |||
| print('{}: {}'.format(i, split_format(sp, s))) | |||
| assert len(space) == expected_space_size | |||
| assert first_connected(sp, space) | |||
| _test(graph_1(), 3) | |||
| _test(graph_2(), 7) | |||
| _test(graph_3(), 4) | |||
| _test(graph_4(), 17) | |||
| _test(graph_5(), 11) | |||
| _test(graph_6(), 24) | |||
| def test_resolve_connnected_graphs(): | |||
| """Test resolve connected graphs""" | |||
| graph = graph_5() | |||
| sp = split.GraphSpliter(graph) | |||
| n1 = get_nodes(sp, ['a', 'd', 'b', 'c']) | |||
| graphs = sp.resolve_connnected_graphs(n1) | |||
| print(graphs) | |||
| assert len(graphs) == 1 | |||
| n2 = get_nodes(sp, ['a', 'd', 'e', 'f', 'g']) | |||
| graphs = sp.resolve_connnected_graphs(n2) | |||
| print(graphs) | |||
| assert len(graphs) == 2 | |||
| n3 = get_nodes(sp, ['a', 'b', 'f']) | |||
| graphs = sp.resolve_connnected_graphs(n3) | |||
| print(graphs) | |||
| assert len(graphs) == 3 | |||
| def test_split(): | |||
| """Test split""" | |||
| def _print_cost(name, c): | |||
| print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" % | |||
| (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type())) | |||
| def _test(graph): | |||
| print("********* test on graph : {} *************".format(graph.name)) | |||
| sp = split.GraphSpliter(graph) | |||
| subgraphs = sp.split(False) | |||
| print('----- main graph -------') | |||
| print(graph) | |||
| for i, g in enumerate(subgraphs): | |||
| print(' -------- subgraph {} -------'.format(i)) | |||
| print(g) | |||
| print("--------- cost ------------") | |||
| cost, _ = model.estimate(graph) | |||
| _print_cost("main graph", cost) | |||
| fc, sub_costs = model.estimate(subgraphs) | |||
| _print_cost("Subgraphs:", fc) | |||
| for i, cost in enumerate(sub_costs): | |||
| _print_cost(" |_%d:\t" % (i), cost) | |||
| _test(graph_5()) | |||
| # _test(graph_4()) | |||
| def test_estimate(): | |||
| """Test estimate""" | |||
| graph = graph_5() | |||
| e = estimate.Estimator(graph) | |||
| e.estimate() | |||
| print(e.iter_space) | |||
| def test_pattern_split(): | |||
| """Test pattern split""" | |||
| def _test(graph, expect_n=0): | |||
| print("************* main graph **************") | |||
| print(graph) | |||
| subgraphs = split.GraphSplitByPatternV2(graph).split() | |||
| for i, g in enumerate(subgraphs): | |||
| print(' -------- subgraph {} -------'.format(i)) | |||
| print(g) | |||
| if expect_n > 0: | |||
| assert len(subgraphs) == expect_n | |||
| # _test(graph_1(), 1) | |||
| # _test(graph_pat_1(), 2) | |||
| # _test(graph_pat_2()) | |||
| # _test(graph_pat_3()) | |||
| # _test(graph_pat_4()) | |||
| # _test(graph_pat_5()) | |||
| # _test(graph_pat_6()) | |||
| # _test(graph_pat_7()) | |||
| # _test(graph_pat_8()) | |||
| # _test(graph_pat_9()) | |||
| # _test(graph_mo_1()) | |||
| # _test(graph_mo_2()) | |||
| # _test(graph_mo_3()) | |||
| _test(graph_mo_4()) | |||
| def main(): | |||
| # test_binary_split() | |||
| # test_resolve_connnected_graphs() | |||
| # test_split() | |||
| # test_estimate() | |||
| test_pattern_split() | |||
| if __name__ == '__main__': | |||
| main() | |||