Merge pull request !7013 from DeshiChen/0928_costmodel_multioutput
@@ -14,138 +14,221 @@
 # ===========================================================================
 """Cost model splitter"""
-from .model import PrimLib, Graph
+from .model import PrimLib, Graph, Tensor


 class GraphSplitByPattern:
-    """Graph split by pattern"""
+    """Graph splitter"""
+
+    class Area:
+        """Area"""
+        MODE_BASIC = 1
+        MODE_COMPOSITE = 2
+
+        def __init__(self, init_op):
+            self.pattern = PrimLib.iter_type(init_op)
+            self.ops = [init_op]
+            self.in_relations = dict()   # {area1: relation1, area2: relation2, ...}
+            self.out_relations = dict()  # {area1: relation1, area2: relation2, ...}
+            self.mode = self.MODE_BASIC
+
+        def __str__(self):
+            return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
+
+        def __repr__(self):
+            return str(self)
+
+        def link_input(self, area_map):
+            """Link inputs"""
+            def get_relation(op, i):
+                relation = PrimLib.UNKNOWN
+                _, elem_relation = PrimLib.input_relation(op, i)
+                for r in elem_relation:
+                    if r is not None and r > relation:
+                        relation = r
+                return relation
+            for i, t in enumerate(self.ops[0].inputs):
+                if t.op is not None:
+                    area, relation = area_map[t.op], get_relation(self.ops[0], i)
+                    self.in_relations[area] = relation
+
+        def link_output(self):
+            """Link outputs"""
+            for input_area, r in self.in_relations.items():
+                input_area.out_relations[self] = r
+
+        def fuse(self, area):
+            """Fuse `area` to `self`"""
+            def _update_relation(relations, a, r):
+                relations[a] = max(r, relations[a]) if a in relations else r
+
+            def _update_pattern():
+                self.pattern = max(self.pattern, area.pattern, self.in_relations[area])
+
+            def _fuse_relation(self_relations, new_relations):
+                for a, r in new_relations.items():
+                    if a != self:
+                        _update_relation(self_relations, a, r)
+                if area in self_relations:
+                    self_relations.pop(area)
+
+            def _redirect_relation(rels):
+                """Replace `area` with `self` in relations"""
+                if area in rels:
+                    r = rels.pop(area)
+                    _update_relation(rels, self, r)
+
+            self.ops.extend(area.ops)
+            _update_pattern()
+            _fuse_relation(self.in_relations, area.in_relations)
+            _fuse_relation(self.out_relations, area.out_relations)
+            for a, _ in area.in_relations.items():
+                _redirect_relation(a.out_relations)
+            for a, _ in area.out_relations.items():
+                _redirect_relation(a.in_relations)
+            self.mode = self.MODE_COMPOSITE
+
+        def check_circle(self, to):
+            """Check circle. It returns false if circle exists"""
+            def _reached(area, to):
+                for out, _ in area.out_relations.items():
+                    if out == to or _reached(out, to):
+                        return True
+                return False
+            for out, _ in self.out_relations.items():
+                if out != to and _reached(out, to):
+                    return False
+            return True
+
+    BORADCAST_FUSE_DEPTH = 3
+    REDUCE_FUSE_DEPTH = 3
+
     def __init__(self, graph):
         self.graph = graph
-        self.groups = []
-        self.op_group = {}
-        for op in self.graph.ops:
-            g = [op]
-            self.groups.append(g)
-            self.op_group[op] = g
-        self.ids = {}
-        for i, op in enumerate(graph.ops):
-            self.ids[op] = i
-        self.doms = self.post_dom(graph.ops)
-        _, outputs = graph.deduce_parameters()
-        self.outputs = set(outputs)
-
-    def post_dom(self, ops):
-        """Post dom"""
-        doms, i_doms = {}, {}
-        for i in range(len(ops) - 1, -1, -1):
-            op = ops[i]
-            doms[op] = {op}
-            i_dom = None
-            if op.output.to_ops:
-                suc_dom = set(doms[op.output.to_ops[0]])
-                for to in op.output.to_ops[1:]:
-                    suc_dom.intersection_update(doms[to])
-                doms[op].update(suc_dom)
-                for dom in suc_dom:
-                    if i_dom is None or self.ids[dom] < self.ids[i_dom]:
-                        i_dom = dom
-            i_doms[op] = i_dom
-        return i_doms
-
-    def get_pattern(self, op, i):
-        """Get pattern"""
-        pattern = PrimLib.UNKNOWN
-        _, elem_relation = PrimLib.input_relation(op, i)
-        for pat in elem_relation:
-            if pat and pat > pattern:
-                pattern = pat
-        return pattern
-
-    def fuse(self, check_fun):
-        """Fuse ops"""
-        def _get_path(op, dom):
-            path_ops, visited = [], set()
-
-            def _get_path_depth(p):
-                visited.add(p)
-                if self.op_group[p][0] == p:
-                    path_ops.append(p)
-                for to in p.output.to_ops:
-                    if to != dom and to not in visited:
-                        _get_path_depth(to)
-            _get_path_depth(op)
-            return path_ops
-        changed = True
-        while changed:
-            for group in self.groups:
-                op = group[0]
-                dom = self.doms[op]
-                if dom is None or op.output in self.outputs:
-                    continue
-                ops = _get_path(op, dom)
-                if check_fun(op, dom, ops):
-                    dom_group = self.op_group[dom]
-                    fused = []
-                    for fop in ops:
-                        f_group = self.op_group[fop]
-                        for p in f_group:
-                            self.op_group[p] = dom_group
-                        fused.append(f_group)
-                        dom_group += f_group
-                    for g in fused:
-                        self.groups.remove(g)
+        self.areas = []
+        area_map = {}
+        for op in graph.ops:
+            a = self.Area(op)
+            self.areas.append(a)
+            area_map[op] = a
+        for a in self.areas:
+            a.link_input(area_map)
+        for a in self.areas:
+            a.link_output()
+
+    def fuse(self, selector):
+        """Fuse areas"""
+        changed = False
+        while True:
+            for dominant in self.areas:
+                fuse_areas = selector(dominant)
+                if fuse_areas:
+                    for area in fuse_areas:
+                        changed = True
+                        dominant.fuse(area)
+                        self.areas.remove(area)
                     break
             else:
-                changed = False
+                return changed
     def to_subgraphs(self):
         """Transform op groups to subgraphs"""
+        ids = {}
+        for i, op in enumerate(self.graph.ops):
+            ids[op] = i
         subgraphs = []
-        for i, group in enumerate(self.groups):
-            group.sort(key=lambda op: self.ids[op])
-            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), group))
-        return subgraphs
+        graphmodes = []
+        for i, area in enumerate(self.areas):
+            area.ops.sort(key=lambda op: ids[op])
+            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops))
+            graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
+        return subgraphs, graphmodes
     def split(self):
-        """Split graph"""
-        def _buddy(op, dom, path_ops):
-            """Fuse buddy together"""
-            group = self.op_group[op]
-            for p in group:
-                # p is buddy
-                if p.output.buddy is not None and p.output.buddy.members[0].op not in group:
+        """Split graph by pattern"""
+        def _elemwise_depth(dom):
+            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE:
+                return None
+            return [a]
+
+        def _elemwise_width(dom):
+            if dom.pattern > PrimLib.BROADCAST:
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom):
+                    fused.append(a)
+            return fused
+
+        def _broadcast_depth(dom):
+            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
+                    r != PrimLib.BROADCAST or len(a.ops) > self.BORADCAST_FUSE_DEPTH:
+                return None
+            return [a]
+
+        def _broadcast_width(dom):
+            if dom.pattern > PrimLib.BROADCAST:
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \
+                        a.check_circle(dom) and len(a.ops) <= self.BORADCAST_FUSE_DEPTH:
+                    fused.append(a)
+            return fused
+
+        def _check_reduce_exclude(dom):
+            # exclude large all-reduce
+            if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
+                    dom.ops[0].inputs[0].get_size() > 10000:
+                return True
+            # exclude multi output
+            for a in dom.in_relations.keys():
+                if len(a.out_relations) > 1:
+                    return True
+                if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]):
                     return True
-                # p's output is buddy
-                for to in p.output.to_ops:
-                    if to.output.buddy is not None and to not in group:
-                        return True
             return False
-
-        def _injective(pattern, limit):
-            def _checker(op, dom, path_ops):
-                for p in op.output.to_ops:
-                    if p not in self.op_group[dom]:
-                        return False
-                if PrimLib.iter_type(op) in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
-                    for i, t in enumerate(dom.inputs):
-                        if t == op.output:
-                            return self.get_pattern(dom, i) == pattern and len(self.op_group[op]) < limit
-                return False
-            return _checker
+
+        def _reduce_depth(dom):
+            if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
+                return None
+            if _check_reduce_exclude(dom):
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
+                    r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH:
+                return None
+            return [a]
-
-        def _diamond(op, dom, path_ops):
-            if PrimLib.iter_type(op) not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
-                    PrimLib.iter_type(dom) in (PrimLib.UNKNOWN, PrimLib.TRANSFORM):
-                return False
-            return len(path_ops) == 1 and op.output not in dom.inputs
-        self.fuse(_buddy)
-        self.fuse(_injective(PrimLib.ELEMWISE, 100))
-        self.fuse(_injective(PrimLib.BROADCAST, 6))
-        self.fuse(_injective(PrimLib.REDUCE, 6))
-        self.fuse(_diamond)
-        return self.to_subgraphs()
+
+        def _reduce_width(dom):
+            if dom.pattern != PrimLib.REDUCE:
+                return None
+            if _check_reduce_exclude(dom):
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \
+                        a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH:
+                    fused.append(a)
+            return fused
+
+        changed = True
+        while changed:
+            changed = self.fuse(_elemwise_depth)
+            changed = self.fuse(_elemwise_width) or changed
+            changed = self.fuse(_broadcast_depth) or changed
+            changed = self.fuse(_broadcast_width) or changed
+            changed = self.fuse(_reduce_depth) or changed
+            changed = self.fuse(_reduce_width) or changed
+        subgraphs, graphmodes = self.to_subgraphs()
+        return subgraphs, graphmodes

 def split(graph):
+    """Split graph"""
     return GraphSplitByPattern(graph).split()
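A minimal driver sketch for the rewritten splitter (not part of the diff; op names and shapes are borrowed from the tests further down): split() now returns the subgraphs together with a per-subgraph mode, "basic" or "composite".

    import model  # the _extends/graph_kernel cost model package

    # Build a toy graph: Abs -> ReduceSum, as in the tests below.
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a = gb.tensor([1024, 16], "float32", name="a")
        b = gb.emit("Abs", a, 'b')
        gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
    graph = gb.get()[0]

    # Two return values now: one mode string per subgraph.
    subgraphs, graphmodes = model.split(graph)
    for g, mode in zip(subgraphs, graphmodes):
        print(mode, g)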
@@ -196,8 +196,7 @@ class CompositeGraph:
                     shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
         cur_fusion = None
         for op in desc['op_desc']:
-            inputs = [self.tensors[d[0]['tensor_name']]
-                      for d in op['input_desc'] if 'value' not in d[0]]
+            inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
             out_desc = op['output_desc']
             name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
                 0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
@@ -263,7 +262,7 @@ class CompositeGraph:
                                              self.tensors[y], True)
                 inplace_desc = copy.deepcopy(d)
                 inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
-                z_desc, out_desc = inplace_desc['input_desc'][2][0].inplace_desc['output_desc'][0]
+                z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
                 z_desc['shape'] = z.shape
                 z_desc['data_type'] = z.dtype
                 z_desc['tensor_name'] = z.name
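A toy illustration of the rewritten `inputs` comprehension above (descriptor values are hypothetical): `input_desc` is a list of descriptor groups, and the new double loop visits every descriptor in every group instead of only the first one, still skipping const inputs that carry a 'value' key.

    input_desc = [
        [{'tensor_name': 'x0'}],
        [{'tensor_name': 'x1'}, {'tensor_name': 'x2'}],  # group with two descriptors
        [{'tensor_name': 'c0', 'value': 1.0}],           # const input, skipped
    ]
    names = [d['tensor_name'] for x in input_desc for d in x if 'value' not in d]
    # names == ['x0', 'x1', 'x2']; the old form read only x[0] per group and missed 'x2'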
@@ -26,10 +26,12 @@ def split_with_json(json_str: str):
     try:
         graph_desc = json.loads(json_str)
         comp = model.load_composite(graph_desc)
-        graph_split = model.split(comp.graph)
+        graph_split, graph_mode = model.split(comp.graph)
         is_multi_graph = len(graph_split) > 1
         graph_list = list(map(comp.dump, graph_split))
-        result = {"multi_graph": is_multi_graph, "graph_desc": graph_list}
+        result = {"multi_graph": is_multi_graph,
+                  "graph_desc": graph_list,
+                  "graph_mode": graph_mode}
         return json.dumps(result)
     except jd.JSONDecodeError:
         logger.error(traceback.format_exc())
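For reference, a hypothetical result payload as built by the new code before json.dumps: "graph_mode" holds one entry per entry of "graph_desc".

    result = {
        "multi_graph": True,                   # len(graph_split) > 1
        "graph_desc": [desc_0, desc_1],        # per-subgraph descs from comp.dump (placeholders)
        "graph_mode": ["basic", "composite"],  # parallel to graph_desc
    }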
@@ -1,53 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ===========================================================================
-"""test split"""
-import model
-
-
-def graph_1():
-    gb = model.GraphBuilder()
-    with gb.graph_scope("main"):
-        a = gb.tensor([1024, 16], "float32", name="a")
-        b = gb.emit("Abs", a, 'b')
-        c = gb.emit("Abs", b, 'c')
-        d = gb.emit("Abs", c, 'd')
-        gb.emit("TensorAdd", [b, d], "e")
-    return gb.get()[0]
-
-
-def graph_2():
-    gb = model.GraphBuilder()
-    with gb.graph_scope("main"):
-        a = gb.tensor([1024, 16], "float32", name="a")
-        b = gb.emit("Abs", a, 'b')
-        c = gb.emit("Abs", b, 'c')
-        d = gb.emit("ReduceSum", c, 'd', attrs={'reduce_axis': (1,)})
-        gb.emit("Sqrt", d, 'e')
-    return gb.get()[0]
-
-
-def test_split_by_pattern():
-    def _test(graph):
-        print("***************** main graph ***************")
-        print(graph)
-        subgraphs = model.split(graph)
-        for i, g in enumerate(subgraphs):
-            print('------------- subgraph {} --------------'.format(i))
-            print(g)
-    _test(graph_2())
-
-
-if __name__ == '__main__':
-    test_split_by_pattern()
@@ -485,7 +485,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyPlatform] = "AKG";
   (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
   (*kernel_json)[kJsonKeyComposite] = true;
-  (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
+  (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();

   if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
     MS_LOG(ERROR) << "Cal mem size failed.";
@@ -37,22 +37,17 @@ namespace opt {
 namespace {
 bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) {
 #if ENABLE_D
-  std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
+  std::vector<PrimitivePtr> fusible_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
                                                  prim::kPrimExpandDims};
   if (!is_before_kernel_select) {
-    fusable_basic_ops.push_back(prim::kPrimCast);
+    fusible_basic_ops.push_back(prim::kPrimCast);
   }
 #elif ENABLE_GPU
-  std::vector<PrimitivePtr> fusable_basic_ops = {
-    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
-    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
-    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
-    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign};
+  std::vector<PrimitivePtr> fusible_basic_ops = GetFusibleOpList();
 #else
-  std::vector<PrimitivePtr> fusable_basic_ops;
+  std::vector<PrimitivePtr> fusible_basic_ops;
 #endif
-  return std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
+  return std::any_of(fusible_basic_ops.begin(), fusible_basic_ops.end(),
                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
 }
@@ -49,12 +49,7 @@ bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
     basic_ops.push_back(prim::kPrimCast);
   }
 #elif ENABLE_GPU
-  std::vector<PrimitivePtr> basic_ops = {
-    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
-    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
-    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
-    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign};
+  std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
 #else
   std::vector<PrimitivePtr> basic_ops;
 #endif
@@ -26,8 +26,8 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/func_graph.h"
 #include "backend/optimizer/pass/const_input_to_attr_registry.h"
-#ifdef ENABLE_D
-#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
+#if ENABLE_GPU
+#include "runtime/device/gpu/kernel_info_setter.h"
 #endif

 namespace mindspore {
@@ -612,36 +612,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
   return new_fg;
 }

-bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
-                   std::vector<AnfNodePtrList> *res_graphs) {
-  MS_EXCEPTION_IF_NULL(res_graphs);
-  auto kernel_json = nlohmann::json::parse(json_desc);
-  if (kernel_json.find(kJsonKeyMultiGraph) == kernel_json.end() || kernel_json[kJsonKeyMultiGraph].is_null()) {
-    // not multi graphs.
-    MS_LOG(ERROR) << "Input json is not multi graph, " << json_desc;
-    return false;
-  }
-
-  kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
-  std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
-  if (graph_descs.empty()) {
-    MS_LOG(ERROR) << "No sub graph found, " << json_desc;
-    return false;
-  }
-
-  for (size_t i = 0; i < graph_descs.size(); ++i) {
-    const auto &graph_desc = graph_descs[i];
-    AnfNodePtrList res_graph;
-    if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
-      MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
-      return false;
-    }
-    res_graphs->push_back(res_graph);
-  }
-
-  return true;
-}
-
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
@@ -664,5 +634,23 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
   }
   return name.str();
 }
+
+std::vector<PrimitivePtr> GetFusibleOpList() {
+  std::vector<PrimitivePtr> fusible_basic_ops = {
+    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
+    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
+    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
+    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
+    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum};
+  return fusible_basic_ops;
+}
+
+void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+#if ENABLE_GPU
+  device::gpu::SetKernelInfo(cnode, kernel_type);
+#endif
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -35,6 +35,7 @@ constexpr auto kGraphKernelSplitFunc = "split_with_json";
 constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
 constexpr auto kJsonKeyMultiGraph = "multi_graph";
 constexpr auto kJsonKeyGraphDesc = "graph_desc";
+constexpr auto kJsonKeyGraphMode = "graph_mode";

 void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                       const AnfNodePtrList &outputs, kernel::Processor processor);
@@ -50,10 +51,10 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
                    std::map<std::string, AnfNodePtr> *address_node_map = nullptr);
 bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
 FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
-bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
-                   std::vector<AnfNodePtrList> *res_graphs);
 std::unordered_set<PrimitivePtr> GetExpandOps();
 std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
+std::vector<PrimitivePtr> GetFusibleOpList();
+void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
@@ -26,6 +26,7 @@
 #include "pipeline/jit/parse/python_adapter.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
 #include "debug/anf_ir_dump.h"
@@ -203,7 +204,7 @@ class AreaGraph {
     }
     SortCNodes(main_cnodes);
-    cnode_group_id->swap(topo_order_);  // The topo_order is not used anymore.
+    *cnode_group_id = std::move(topo_order_);  // The topo_order is not used anymore.
     return;
   }
@@ -291,7 +292,7 @@ class AreaGraph {
     std::vector<CNodePtr> main_cnodes_sorted;
     std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
                    [main_cnodes](int index) { return main_cnodes->at(index); });
-    main_cnodes->swap(main_cnodes_sorted);
+    *main_cnodes = std::move(main_cnodes_sorted);
   }

   // Areas in this subgraph
@@ -415,6 +416,9 @@ class Splitter {
           cnode->set_input(i, iter->second);
         }
       }
+      if (AnfAlgo::IsRealKernel(node)) {
+        ResetKernelInfo(node);
+      }
     }
   }
   return output;
@@ -445,7 +449,7 @@ class Splitter {
         tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]);
       }
     }
-    new_subgraph_cnodes_.swap(tmp_subgraph_cnodes);
+    new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes);

     TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
       auto cnode = node->cast<CNodePtr>();
@@ -580,15 +584,38 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
       return false;
     }

-    // recover json to anf-ir.
-    split_plan_.clear();
-    if (!JsonDescToAnf(split_graphs_str, address_node_map, &split_plan_)) {
-      MS_LOG(ERROR) << "Failed to decode split graphs.";
+    if (!DecodeJson(split_graphs_str, address_node_map)) {
+      MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str;
       return false;
     }
+    return true;
+  }
+
+  virtual bool DecodeJson(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map) {
+    auto kernel_json = nlohmann::json::parse(json_desc);
+    kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
+    std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
+    std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode];
+    if (graph_modes.size() != graph_descs.size()) {
+      MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatch graph_desc " << graph_descs.size();
+      return false;
+    }
+
+    // recover json to anfnode.
+    split_plan_.clear();
+    for (const auto &graph_desc : graph_descs) {
+      AnfNodePtrList res_graph;
+      if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
+        MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
+        return false;
+      }
+      split_plan_.push_back(std::move(res_graph));
+    }

-    // The info should be returned from costmodel.
-    need_inline_.assign(split_plan_.size(), 0);
+    // ops to be inlined.
+    need_inline_.clear();
+    std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_),
+                   [](const std::string &mode) { return mode == "basic" ? 1 : 0; });
     return true;
   }
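The std::transform above maps each mode string to an inline flag; a Python sketch of the same mapping (values illustrative): "basic" subgraphs are marked for inlining back into the main graph, while "composite" ones are kept as fused graph kernels.

    graph_modes = ["basic", "composite", "basic"]  # from the "graph_mode" json field
    need_inline = [1 if mode == "basic" else 0 for mode in graph_modes]
    assert need_inline == [1, 0, 1]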
@@ -13,5 +13,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-PYTHONPATH="$(pwd)/..:${PYTHONPATH}"
+PYTHONPATH="$(pwd)/../../../../mindspore/_extends/graph_kernel:${PYTHONPATH}"
 export PYTHONPATH
@@ -0,0 +1,436 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""Test split"""
+import model
+from model import model as estimate
+from model import graph_split as split
+
+
+def get_nodes(sp, ops):
+    """Get nodes"""
+    if isinstance(ops[0], str):
+        new_ops = []
+        for t in ops:
+            for op in sp.graph.ops:
+                if op.output.name == t:
+                    new_ops.append(op)
+                    break
+            else:
+                print("ERROR: not found op: ", t)
+        ops = new_ops
+    return [sp.nodes[sp.graph.ops.index(op)] for op in ops]
+
+
+def first_connected(sp, space):
+    for cand in space:
+        nodes = [sp.nodes[i] for i in cand[0]]
+        graphs = sp.resolve_connnected_graphs(nodes)
+        if len(graphs) != 1:
+            print("connect check failed: ", nodes)
+            return False
+    return True
+
+
+def split_format(sp, cand):
+    names = []
+    for ids in cand:
+        ops = []
+        for i in ids:
+            ops.append(sp.graph.ops[i].output.name)
+        names.append(','.join(ops))
+    return '|'.join(names)
+
+def graph_1():
+    ''' ring, no succ_dep, no prev '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a = gb.tensor([10240, 16], "float32", name="a")
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", c, 'd')
+        gb.emit('TensorAdd', [b, d], 'e')
+    return gb.get()[0]
+
+
+def graph_2():
+    ''' ring, succ_dep, no prev '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", a, 'c')
+        d = gb.emit("Abs", b, 'd')
+        e = gb.emit('TensorAdd', [c, d], 'e')
+        gb.emit("Abs", e, 'f')
+    return gb.get()[0]
+
+
+def graph_3():
+    ''' no ring, 1 sibling node '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        b = gb.emit("Abs", a0, 'b')
+        c = gb.emit("Abs", a1, 'c')
+        d = gb.emit("Abs", b, 'd')
+        e = gb.emit('TensorAdd', [c, d], 'e')
+        gb.emit("Abs", e, 'f')
+    return gb.get()[0]
+
+
+def graph_4():
+    ''' no ring, 2 sibling nodes in 1 step '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        b = gb.emit("Abs", a0, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", a1, 'd')
+        e = gb.emit("Abs", d, 'e')
+        f = gb.emit('TensorAdd', [c, e], 'f')
+        gb.emit('Abs', f, 'g')
+        h = gb.emit("Abs", d, 'h')
+        i = gb.emit('TensorAdd', [c, h], 'i')
+        gb.emit("Abs", i, 'j')
+    return gb.get()[0]
+
+
+def graph_5():
+    ''' no ring, 2 sibling step '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        a2 = gb.tensor([10240, 16], "float32", name="a2")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a1, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit('TensorAdd', [a, c], 'd')
+        gb.emit("Abs", d, 'e')
+        f = gb.emit("Abs", a2, 'f')
+        g = gb.emit('TensorAdd', [c, f], 'g')
+        gb.emit("Abs", g, 'h')
+    return gb.get()[0]
+
+
+def graph_6():
+    ''' no ring, tree down '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        gb.emit("Abs", b, 'd')
+        gb.emit("Abs", b, 'e')
+        c = gb.emit("Abs", a, 'c')
+        gb.emit("Abs", c, 'f')
+        gb.emit("Abs", c, 'g')
+    return gb.get()[0]
+
+def graph_pat_1():
+    ''' split by reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Sqrt", c, 'd')
+        gb.emit("Sqrt", d, 'f')
+    return gb.get()[0]
+
+
+def graph_pat_2():
+    ''' multi output '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)})
+    return gb.get()[0]
+
+
+def graph_pat_3():
+    ''' two reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Abs", c, 'd')
+        gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)})
+    return gb.get()[0]
+
+
+def graph_pat_4():
+    ''' elemwise + broadcast '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1, 1024], "float32", name="a0")
+        a2 = gb.tensor([1014, 1024], "float32", name="a2")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", c, 'd')
+        e = gb.emit("Abs", d, 'e')
+        f = gb.emit("Abs", e, 'f')
+        g0 = gb.emit("Abs", a2, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        g0 = gb.emit("Abs", g0, 'g0')
+        g1 = gb.emit('TensorAdd', [f, g0], 'g1')
+        g2 = gb.emit("Abs", g1, 'g2')
+        g3 = gb.emit("Abs", g2, 'g3')
+        g4 = gb.emit("Abs", g3, 'g4')
+        gb.emit("Abs", g4, 'g5')
+    return gb.get()[0]
+
+
+def graph_pat_5():
+    ''' reduce + reshape '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Abs", c, 'd')
+        e = gb.tensor([512, 2048], "float32", name="e")
+        gb.op("Reshape", e, [d])
+    return gb.get()[0]
+
+
+def graph_pat_6():
+    ''' diamond '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", a, 'c')
+        gb.emit("TensorAdd", [b, c], 'd')
+        gb.emit("Abs", c, 'f')  # broke diamond
+    return gb.get()[0]
+
+
+def graph_pat_7():
+    ''' buddy of control op '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a1 = gb.tensor([1024, 1024], "float32", name="a1")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a1, 'b')
+        c = gb.emit("make_tuple", [a, b], 'c')
+        d = gb.tensor([1024, 1024], "float32", name="d")
+        gb.op("AddN", d, [c])
+        gb.emit("Abs", d, 'f')
+    graph = gb.get()[0]
+    estimate.AddControlBuddy().visit_graph(graph)
+    return graph
+
+
+def graph_pat_8():
+    ''' reduce + reshape '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        # c = gb.emit("Abs", b, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        gb.emit("TensorAdd", [b, c], 'd')
+    return gb.get()[0]
+
+
+def graph_pat_9():
+    ''' scalar '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a1 = gb.tensor([1], "float32", name="a1")
+        a = gb.emit("Maximum", a1, 'a')
+        b = gb.emit("Mul", [a, a1], 'b')
+        gb.emit('Mul', [b, a0], 'c')
+    return gb.get()[0]
+
+def graph_mo_1():
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        gb.emit("Abs", a, 'b')
+        gb.emit("Abs", a, 'c')
+    return gb.get()[0]
+
+
+def graph_mo_2():
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        g.set_output(b, c)
+    return gb.get()[0]
+
+
+def graph_mo_3():
+    ''' two reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        g.set_output(b, c)
+    return gb.get()[0]
+
+
+def graph_mo_4():
+    ''' two reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)})
+        g.set_output(b, c)
+    return gb.get()[0]
+
+def test_binary_split():
+    """Test binary split"""
+    def _test(graph, expected_space_size):
+        print("********* test on graph : {} *************".format(graph.name))
+        sp = split.GraphSpliter(graph)
+        nodes = get_nodes(sp, graph.ops)
+        space = sp.binary_split(nodes)
+        for i, s in enumerate(space):
+            print('{}: {}'.format(i, split_format(sp, s)))
+        assert len(space) == expected_space_size
+        assert first_connected(sp, space)
+    _test(graph_1(), 3)
+    _test(graph_2(), 7)
+    _test(graph_3(), 4)
+    _test(graph_4(), 17)
+    _test(graph_5(), 11)
+    _test(graph_6(), 24)
+
+
+def test_resolve_connnected_graphs():
+    """Test resolve connected graphs"""
+    graph = graph_5()
+    sp = split.GraphSpliter(graph)
+    n1 = get_nodes(sp, ['a', 'd', 'b', 'c'])
+    graphs = sp.resolve_connnected_graphs(n1)
+    print(graphs)
+    assert len(graphs) == 1
+    n2 = get_nodes(sp, ['a', 'd', 'e', 'f', 'g'])
+    graphs = sp.resolve_connnected_graphs(n2)
+    print(graphs)
+    assert len(graphs) == 2
+    n3 = get_nodes(sp, ['a', 'b', 'f'])
+    graphs = sp.resolve_connnected_graphs(n3)
+    print(graphs)
+    assert len(graphs) == 3
+
+
+def test_split():
+    """Test split"""
+    def _print_cost(name, c):
+        print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
+              (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
+
+    def _test(graph):
+        print("********* test on graph : {} *************".format(graph.name))
+        sp = split.GraphSpliter(graph)
+        subgraphs = sp.split(False)
+        print('----- main graph -------')
+        print(graph)
+        for i, g in enumerate(subgraphs):
+            print(' -------- subgraph {} -------'.format(i))
+            print(g)
+        print("--------- cost ------------")
+        cost, _ = model.estimate(graph)
+        _print_cost("main graph", cost)
+        fc, sub_costs = model.estimate(subgraphs)
+        _print_cost("Subgraphs:", fc)
+        for i, cost in enumerate(sub_costs):
+            _print_cost("  |_%d:\t" % (i), cost)
+    _test(graph_5())
+    # _test(graph_4())
+
+
+def test_estimate():
+    """Test estimate"""
+    graph = graph_5()
+    e = estimate.Estimator(graph)
+    e.estimate()
+    print(e.iter_space)
+
+
+def test_pattern_split():
+    """Test pattern split"""
+    def _test(graph, expect_n=0):
+        print("************* main graph **************")
+        print(graph)
+        subgraphs = split.GraphSplitByPatternV2(graph).split()
+        for i, g in enumerate(subgraphs):
+            print(' -------- subgraph {} -------'.format(i))
+            print(g)
+        if expect_n > 0:
+            assert len(subgraphs) == expect_n
+    # _test(graph_1(), 1)
+    # _test(graph_pat_1(), 2)
+    # _test(graph_pat_2())
+    # _test(graph_pat_3())
+    # _test(graph_pat_4())
+    # _test(graph_pat_5())
+    # _test(graph_pat_6())
+    # _test(graph_pat_7())
+    # _test(graph_pat_8())
+    # _test(graph_pat_9())
+    # _test(graph_mo_1())
+    # _test(graph_mo_2())
+    # _test(graph_mo_3())
+    _test(graph_mo_4())
+
+
+def main():
+    # test_binary_split()
+    # test_resolve_connnected_graphs()
+    # test_split()
+    # test_estimate()
+    test_pattern_split()
+
+
+if __name__ == '__main__':
+    main()