| @@ -16,7 +16,7 @@ | |||
| import os | |||
| from functools import reduce | |||
| from mindspore import log as logger | |||
| from .model import PrimLib, Graph, Tensor | |||
| from .model import PrimLib, Graph, Tensor, Operator | |||
| from .model import DataFormat as DF | |||
| @@ -65,13 +65,16 @@ class GraphSplitByPattern: | |||
| self.stitch_ops = set() | |||
| self.stitch_atomic_ops = set() | |||
| def __init__(self, init_op, is_output, unique_id, reach_tab): | |||
| self.pattern = PrimLib.iter_type(init_op) | |||
| self.ops = [init_op] | |||
| def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None): | |||
| self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN | |||
| self.ops = [] if init_op is None else [init_op] | |||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| self.mode = None | |||
| self.stitch_info = self.StitchInfo() | |||
| self.recompute_ops = [] if recompute_ops is None else recompute_ops | |||
| self.ori_op_map = {} | |||
| self.is_recompute = False | |||
| self.is_output = is_output | |||
| self.output_excluded = set() | |||
| if self.pattern == PrimLib.REDUCE: | |||
| @@ -143,6 +146,8 @@ class GraphSplitByPattern: | |||
| r = rels.pop(area) | |||
| _update_relation(rels, self, r) | |||
| if area.is_recompute: | |||
| self.cp_ops(area) | |||
| if self.pattern >= area.pattern: | |||
| self.ops.extend(area.ops) | |||
| else: | |||
| @@ -161,7 +166,9 @@ class GraphSplitByPattern: | |||
| if area.output_excluded: | |||
| self.output_excluded.update(area.output_excluded) | |||
| self.update_stitch_info(area.stitch_info) | |||
| self.reach_tab.fuse(self.unique_id, area.unique_id) | |||
| if not area.is_recompute: | |||
| self.reach_tab.fuse(self.unique_id, area.unique_id) | |||
| self.recompute_ops.extend(area.recompute_ops) | |||
| def check_acyclic(self, to): | |||
| """Check circle. It returns false if circle exists""" | |||
| @@ -180,25 +187,73 @@ class GraphSplitByPattern: | |||
| return True | |||
| return False | |||
| def cp_ops(self, area): | |||
| """copy recompute_ops in area to ops, self is area's user""" | |||
| tail_tensor = area.recompute_ops[-1].output | |||
| #copy tensors; all copies are Tensor.PARA_NONE | |||
| tensor_map = {} | |||
| tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0] | |||
| for op in area.recompute_ops: | |||
| orig_tensor = op.output | |||
| cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format) | |||
| tensor_map[orig_tensor] = cp_tensor | |||
| #copy ops | |||
| cp_ops = [] | |||
| for op in area.recompute_ops: | |||
| cp_op = Operator(op.prim, [tensor_map[op.inputs[0]]], tensor_map[op.output], op.attrs) | |||
| cp_op.all_inputs = cp_op.inputs | |||
| cp_ops.append(cp_op) | |||
| area.ori_op_map[cp_op] = op | |||
| #connect copied ops | |||
| for op in self.ops: | |||
| if tail_tensor in op.inputs: | |||
| op.inputs.remove(tail_tensor) | |||
| op.inputs.append(tensor_map[tail_tensor]) | |||
| tail_tensor.to_ops.remove(op) | |||
| tensor_map[tail_tensor].to_ops.append(op) | |||
| #fill cp_ops into the recompute area's ops, dom op first | |||
| cp_dom_op = None | |||
| for cp, ori in area.ori_op_map.items(): | |||
| if ori == area.dom_op(): | |||
| cp_dom_op = cp | |||
| area.ops.clear() | |||
| area.ops.append(cp_dom_op) | |||
| area.ops.extend([op for op in cp_ops if op != cp_dom_op]) | |||
| def __init__(self, graph, flags): | |||
| self.graph = graph | |||
| self.areas = [] | |||
| self.flags = flags | |||
| self.reach_tab = self.ReachTable(len(graph.ops)) | |||
| area_map = {} | |||
| self.enable_recompute = self.flags.get("enable_recompute_fusion", False) | |||
| self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops)) | |||
| self.area_map = {} | |||
| _, outputs = graph.deduce_parameters() | |||
| idx = 0 | |||
| self.idx = 0 | |||
| for op in graph.ops: | |||
| is_output = op.output in outputs | |||
| a = self.Area(op, is_output, idx, self.reach_tab) | |||
| idx += 1 | |||
| a = self.Area(op, is_output, self.idx, self.reach_tab) | |||
| self.idx += 1 | |||
| self.set_default_mode(a) | |||
| self.areas.append(a) | |||
| area_map[op] = a | |||
| self.set_area_map([op], a) | |||
| for a in self.areas: | |||
| a.link_input(area_map) | |||
| a.link_input(self.area_map) | |||
| for i in range(len(self.areas)-1, -1, -1): | |||
| self.areas[i].link_output() | |||
| if self.enable_recompute: | |||
| self.recom_area = self.Area(None, False, self.idx, self.reach_tab) | |||
| self.recom_area.is_recompute = True | |||
| self.recom_pre = None | |||
| self.recom_user = None | |||
| self.recom_dom = None | |||
| self.dom_user_r = PrimLib.UNKNOWN | |||
| self.recom_res = False | |||
| self.orig_op_map = {} | |||
| def set_area_map(self, ops, area): | |||
| """update area_map after op fused to area""" | |||
| for op in ops: | |||
| self.area_map[op] = area | |||
| def set_default_mode(self, area): | |||
| area.mode = self.get_default_mode(area.ops[0]) | |||
| @@ -234,11 +289,13 @@ class GraphSplitByPattern: | |||
| if is_forward: | |||
| for area in fuse_areas: | |||
| dominant.fuse(area) | |||
| self.set_area_map(area.ops, dominant) | |||
| self.areas.remove(area) | |||
| else: | |||
| forward_area = dominant | |||
| for area in fuse_areas: | |||
| area.fuse(forward_area) | |||
| self.set_area_map(forward_area.ops, area) | |||
| self.areas.remove(forward_area) | |||
| forward_area = area | |||
| changed = True | |||
| @@ -246,16 +303,39 @@ class GraphSplitByPattern: | |||
| else: | |||
| return changed | |||
| def to_subgraphs(self): | |||
| """Transform op groups to subgraphs""" | |||
| def fuse_recom(self, selector): | |||
| """Fuse recompute area to its user""" | |||
| for dominant in [self.recom_area, self.recom_user]: | |||
| result = selector(dominant) | |||
| if result is not None and result[0]: | |||
| fuse_areas, _ = result | |||
| fuse_areas = self.limit_area_size(dominant, fuse_areas) | |||
| if not fuse_areas: | |||
| continue | |||
| if fuse_areas[0] in [self.recom_area, self.recom_user]: | |||
| self.recom_user.fuse(self.recom_area) | |||
| self.recom_res = True | |||
| return True | |||
| return False | |||
| def index_op(self): | |||
| """index op by order, the copied op share id with original op, for topo-sort""" | |||
| ids = {} | |||
| for i, op in enumerate(self.graph.ops): | |||
| ids[op] = i | |||
| if hasattr(self, 'orig_op_map'): | |||
| for k, v in self.orig_op_map.items(): | |||
| ids[k] = ids[v] | |||
| return ids | |||
| def to_subgraphs(self): | |||
| """Transform op groups to subgraphs""" | |||
| ids = self.index_op() | |||
| subgraphs = [] | |||
| graphmodes = [] | |||
| for i, area in enumerate(self.areas): | |||
| area.ops.sort(key=lambda op: ids[op]) | |||
| subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info)) | |||
| subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops)) | |||
| graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite") | |||
| return subgraphs, graphmodes | |||
| @@ -274,13 +354,14 @@ class GraphSplitByPattern: | |||
| with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f: | |||
| f.write(subgraphs_str) | |||
| def do_split(self): | |||
| """Split graph by pattern""" | |||
| raise Exception("do_split() is not implemented in {}".format(self.__class__.__name__)) | |||
| def pattern_fuse(self, select=None): | |||
| """fuse Areas by pattern repeatedly""" | |||
| raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__)) | |||
| def split(self): | |||
| """Split graph by pattern""" | |||
| self.do_split() | |||
| self.pattern_fuse() | |||
| self.recompute_fuse() | |||
| # The reshape should not be output node | |||
| # Note: after this function, the input output relation is not maintained. | |||
| self.split_output_reshapes() | |||
| @@ -316,6 +397,159 @@ class GraphSplitByPattern: | |||
| if new_areas: | |||
| self.areas += new_areas | |||
| def set_recompute(self, dom_area, ops, user_area): | |||
| """set the recompute area and connect with other areas""" | |||
| self.recom_area.recompute_ops.extend(ops) | |||
| #recom_area: set its dom op (max pattern) and pad ops to the correct length | |||
| patterns = [PrimLib.iter_type(op) for op in ops] | |||
| self.recom_area.pattern = max(patterns) | |||
| for i, pat in enumerate(patterns): | |||
| if pat == self.recom_area.pattern: | |||
| self.recom_area.ops = [ops[i]] * len(ops) | |||
| break | |||
| #disconnect dom_area and user_area | |||
| self.dom_user_r = dom_area.out_relations[user_area] | |||
| dom_area.out_relations.pop(user_area) | |||
| user_area.in_relations.pop(dom_area) | |||
| #connect recom_area and user_area | |||
| user_area.in_relations[self.recom_area] = self.dom_user_r | |||
| self.recom_area.out_relations[user_area] = self.dom_user_r | |||
| #connect recom_pre and recom_area | |||
| self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs[0].op else None | |||
| if self.recom_pre is not None: | |||
| self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre] | |||
| self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre] | |||
| #set related areas | |||
| self.recom_user = user_area | |||
| self.recom_dom = dom_area | |||
| self.recom_res = False | |||
| def clear_recompute(self): | |||
| """disconnect recom_area from other areas, and clear recom_area""" | |||
| self.recom_area.out_relations.clear() | |||
| self.recom_area.in_relations.clear() | |||
| if not self.recom_res: | |||
| self.recom_user.in_relations.pop(self.recom_area) | |||
| self.recom_user.in_relations[self.recom_dom] = self.dom_user_r | |||
| self.recom_dom.out_relations[self.recom_user] = self.dom_user_r | |||
| if self.recom_pre: | |||
| self.recom_pre.out_relations.pop(self.recom_area) | |||
| self.recom_area.ops.clear() | |||
| self.recom_area.recompute_ops.clear() | |||
| self.orig_op_map.update(self.recom_area.ori_op_map) | |||
| self.recom_area.ori_op_map.clear() | |||
| def to_subgraph(self, dom): | |||
| """Transform area to subgraphs""" | |||
| ids = self.index_op() | |||
| dom_ops = list() | |||
| dom_ops.extend(dom.ops) | |||
| dom_ops.sort(key=lambda op: ids[op]) | |||
| subgraph = Graph('{}_area'.format(self.graph.name), dom_ops) | |||
| return subgraph | |||
| def find_cheap_regions(self, dom): | |||
| """extract all the cheap regions in dom area, toposort each region before return""" | |||
| def _grow_region(region_ops, op, weight, inputs): | |||
| """include op to region_ops if region grow""" | |||
| # region successfully ends at input | |||
| if op.inputs[0] in inputs and len(op.inputs) == 1 and \ | |||
| PrimLib.iter_type(op) <= PrimLib.BROADCAST: | |||
| region_ops.append(op) | |||
| return False, None, weight | |||
| #region fails to grow | |||
| MAX_WEIGHT = 20 | |||
| if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST: | |||
| return False, None, weight | |||
| #region grows successfully | |||
| weight = weight + 1 | |||
| region_ops.append(op) | |||
| return True, op.inputs[0].op, weight | |||
| def _find_cheap_regions(dom): | |||
| sub = self.to_subgraph(dom) | |||
| inputs, outputs = sub.deduce_parameters() | |||
| cheap_regions = [] | |||
| for output in outputs: | |||
| # the tensor must be a graph output or have more than one user; otherwise its producer could simply be fused into the single user | |||
| if output.para_type != Tensor.PARA_OUTPUT and len(output.to_ops) < 2: | |||
| continue | |||
| region_ops = [] | |||
| grow = True | |||
| candidate_op = output.op | |||
| weight = 1 | |||
| while grow: | |||
| grow, candidate_op, weight = _grow_region(region_ops, candidate_op, weight, inputs) | |||
| # region ends at input and not empty | |||
| if region_ops and region_ops[-1].inputs[0] in inputs: | |||
| region_ops.reverse() | |||
| # tensor size should stay equal or become larger (cast up, broadcast) | |||
| if region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size(): | |||
| continue | |||
| cheap_regions.append(region_ops) | |||
| return cheap_regions | |||
| return _find_cheap_regions(dom) | |||
| def select_user_area(self, tail_tensor): | |||
| """select the user area has only one edge to dom area""" | |||
| def _get_edge_num(dom_area, user_area): | |||
| """get edge num between two areas""" | |||
| dom_graph = self.to_subgraph(dom_area) | |||
| _, dom_outputs = dom_graph.deduce_parameters() | |||
| user_graph = self.to_subgraph(user_area) | |||
| user_inputs, _ = user_graph.deduce_parameters() | |||
| edge = [t for t in dom_outputs if t in user_inputs] | |||
| return len(edge) | |||
| def _select_user_area(tail_tensor): | |||
| user_areas = [] | |||
| for user_op in tail_tensor.to_ops: | |||
| user_area = self.area_map[user_op] | |||
| if len(user_area.ops) == 1 and user_area.pattern == PrimLib.RESHAPE: | |||
| continue | |||
| edge_num = _get_edge_num(self.area_map[tail_tensor.op], user_area) | |||
| if edge_num == 1 and not user_area in user_areas: | |||
| user_areas.append(user_area) | |||
| return user_areas | |||
| return _select_user_area(tail_tensor) | |||
| def recompute_fuse(self): | |||
| """find recompute regions and copy them out to new Areas""" | |||
| def do_recompute_fuse(): | |||
| """split the unfusing pattern by add recompute area""" | |||
| recompute_suc = False | |||
| orig_areas = [] | |||
| orig_areas.extend(self.areas) | |||
| for dom in orig_areas: | |||
| if dom not in self.areas or not dom.out_relations: | |||
| continue | |||
| cheap_regions = self.find_cheap_regions(dom) | |||
| dom_changed = False | |||
| for cheap_region in cheap_regions: | |||
| user_areas = self.select_user_area(cheap_region[-1].output) | |||
| if not user_areas: | |||
| continue | |||
| for user_area in user_areas: | |||
| self.set_recompute(dom, cheap_region, user_area) | |||
| self.pattern_fuse(self.fuse_recom) | |||
| self.clear_recompute() | |||
| if self.recom_res: | |||
| recompute_suc = True | |||
| #Copy region at most once for this dom | |||
| dom_changed = True | |||
| break | |||
| if dom_changed: | |||
| break | |||
| return recompute_suc | |||
| if self.enable_recompute: | |||
| while do_recompute_fuse(): | |||
| self.pattern_fuse() | |||
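To make the effect of recompute_fuse concrete, here is an illustrative sketch (plain Python data, not MindSpore API; all names are made up) of the transformation it enables: a cheap producer whose output blocks fusion because it feeds several consumer areas is duplicated, so each consumer can fuse its own copy.

```python
# Illustrative sketch only: how recomputing a cheap op unblocks fusion.
# Before: a single Cast feeds two ReduceSum areas, so its output must be
# materialized and neither consumer area can absorb the Cast.
before = {
    "cast0": {"inputs": ["x"], "users": ["sum0", "sum1"]},
    "sum0":  {"inputs": ["cast0"]},
    "sum1":  {"inputs": ["cast0"]},
}
# After: the Cast is copied for one consumer (cf. cp_ops / set_recompute above),
# and each {Cast, ReduceSum} pair becomes an independently fusable area.
after = {
    "cast0":      {"inputs": ["x"], "users": ["sum0"]},
    "cast0_copy": {"inputs": ["x"], "users": ["sum1"]},  # recomputed copy
    "sum0":       {"inputs": ["cast0"]},
    "sum1":       {"inputs": ["cast0_copy"]},
}
```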
| use_poly_reduce = True | |||
| @@ -331,8 +565,8 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| pattern = PrimLib.iter_type(op) | |||
| return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE | |||
| def do_split(self): | |||
| """Split graph by pattern""" | |||
| def pattern_fuse(self, fuse_func=None): | |||
| """fuse Areas by pattern""" | |||
| def _reshape(dom): | |||
| if dom.pattern != PrimLib.RESHAPE: | |||
| return None | |||
| @@ -551,21 +785,38 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| fused.append(a) | |||
| return fused, True | |||
| enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) | |||
| changed = True | |||
| while changed: | |||
| changed = self.fuse(_reshape) | |||
| changed = self.fuse(_elemwise_depth) or changed | |||
| changed = self.fuse(_elemwise_width) or changed | |||
| changed = self.fuse(_reduce_depth) or changed | |||
| changed = self.fuse(_reduce_width) or changed | |||
| changed = self.fuse(_broadcast_depth) or changed | |||
| changed = self.fuse(_broadcast_width) or changed | |||
| def _fuse_loop(): | |||
| changed = True | |||
| while changed: | |||
| changed = self.fuse(_reshape) | |||
| changed = self.fuse(_elemwise_depth) or changed | |||
| changed = self.fuse(_elemwise_width) or changed | |||
| changed = self.fuse(_reduce_depth) or changed | |||
| changed = self.fuse(_reduce_width) or changed | |||
| changed = self.fuse(_broadcast_depth) or changed | |||
| changed = self.fuse(_broadcast_width) or changed | |||
| if use_poly_reduce: | |||
| changed = self.fuse(_reduce_output) or changed | |||
| if enable_stitch_fusion: | |||
| changed = self.fuse(_reduce_stitch) or changed | |||
| self.fuse(_transpose) | |||
| def _fuse_once(fuse_func): | |||
| if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \ | |||
| fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ | |||
| fuse_func(_broadcast_width): | |||
| return | |||
| if use_poly_reduce: | |||
| changed = self.fuse(_reduce_output) or changed | |||
| if enable_stitch_fusion: | |||
| changed = self.fuse(_reduce_stitch) or changed | |||
| self.fuse(_transpose) | |||
| if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)): | |||
| return | |||
| fuse_func(_transpose) | |||
| return | |||
| enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) | |||
| if fuse_func is None: | |||
| _fuse_loop() | |||
| else: | |||
| _fuse_once(fuse_func) | |||
| class GraphSplitAscend(GraphSplitByPattern): | |||
| @@ -580,8 +831,8 @@ class GraphSplitAscend(GraphSplitByPattern): | |||
| return self.Area.MODE_COMPOSITE | |||
| return self.Area.MODE_BASIC | |||
| def do_split(self): | |||
| """Split graph by pattern""" | |||
| def pattern_fuse(self, fuse_func=None): | |||
| """fuse Areas by pattern""" | |||
| def _tensor_size(tensor): | |||
| size = 1 | |||
| for i in tensor.shape: | |||
| @@ -685,6 +936,19 @@ class GraphSplitAscend(GraphSplitByPattern): | |||
| fused.append(a) | |||
| return fused, False | |||
| def _reduce_output(dom): | |||
| if dom.pattern != PrimLib.REDUCE: | |||
| return None | |||
| op_attrs = dom.dom_op().attrs | |||
| if not op_attrs.get('reduce_output_fuse'): | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | |||
| dom.check_acyclic(a): | |||
| fused.append(a) | |||
| return fused, False | |||
| def _transdata_pattern_support(dom, a): | |||
| transdata_op = dom.dom_op() | |||
| @@ -733,32 +997,31 @@ class GraphSplitAscend(GraphSplitByPattern): | |||
| fused.append(a) | |||
| return fused, True | |||
| def _reduce_output(dom): | |||
| if dom.pattern != PrimLib.REDUCE: | |||
| return None | |||
| op_attrs = dom.dom_op().attrs | |||
| if not op_attrs.get('reduce_output_fuse'): | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | |||
| dom.check_acyclic(a): | |||
| fused.append(a) | |||
| return fused, False | |||
| changed = True | |||
| while changed: | |||
| changed = self.fuse(_reshape) | |||
| changed = self.fuse(_elemwise_depth) or changed | |||
| changed = self.fuse(_elemwise_width) or changed | |||
| changed = self.fuse(_reduce_depth) or changed | |||
| changed = self.fuse(_reduce_width) or changed | |||
| changed = self.fuse(_broadcast_depth) or changed | |||
| changed = self.fuse(_broadcast_width) or changed | |||
| changed = self.fuse(_matmul_depth) or changed | |||
| changed = self.fuse(_reduce_output) or changed | |||
| self.fuse(_transdata) | |||
| def _fuse_loop(): | |||
| changed = True | |||
| while changed: | |||
| changed = self.fuse(_reshape) | |||
| changed = self.fuse(_elemwise_depth) or changed | |||
| changed = self.fuse(_elemwise_width) or changed | |||
| changed = self.fuse(_reduce_depth) or changed | |||
| changed = self.fuse(_reduce_width) or changed | |||
| changed = self.fuse(_broadcast_depth) or changed | |||
| changed = self.fuse(_broadcast_width) or changed | |||
| changed = self.fuse(_matmul_depth) or changed | |||
| changed = self.fuse(_reduce_output) or changed | |||
| self.fuse(_transdata) | |||
| def _fuse_once(fuse_func): | |||
| if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \ | |||
| fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ | |||
| fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \ | |||
| fuse_func(_transdata): | |||
| pass | |||
| if fuse_func is None: | |||
| _fuse_loop() | |||
| else: | |||
| _fuse_once(fuse_func) | |||
| def split(graph, target, flags): | |||
| @@ -320,8 +320,8 @@ class Operator: | |||
| def __str__(self): | |||
| args = ', '.join([str(t) for t in self.all_inputs]) | |||
| expr = "%s = %s.%s(%s)" % ( | |||
| str(self.output), self.prim, self.output.dtype, args) | |||
| expr = "%s = %s.%s(%s) id:%s" % ( | |||
| str(self.output), self.prim, self.output.dtype, args, id(self)) | |||
| return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs)) | |||
| def __repr__(self): | |||
| @@ -331,12 +331,13 @@ class Operator: | |||
| class Graph: | |||
| """Graph""" | |||
| def __init__(self, name, ops, stitch_info=None): | |||
| def __init__(self, name, ops, stitch_info=None, recompute_ops=None): | |||
| self.name = name | |||
| self.ops = ops # in topo order, can not use set | |||
| self.inputs = [] | |||
| self.outputs = [] | |||
| self.stitch_info = stitch_info | |||
| self.recompute_ops = recompute_ops | |||
| def set_processor(self, processor): | |||
| """Set processor""" | |||
| @@ -203,11 +203,13 @@ class CompositeGraph: | |||
| desc['buffer_stitch'] = buffer_stitch | |||
| return desc | |||
| def dump(self, subgraph): | |||
| """Dump Graph to json""" | |||
| desc = {} | |||
| inputs, outputs = subgraph.deduce_parameters() | |||
| graph_ops = set(subgraph.ops) | |||
| def add_recompute_ops(self, subgraph, desc): | |||
| """Add the recompute ops' output tensor names to desc""" | |||
| if subgraph.recompute_ops: | |||
| desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops] | |||
| return desc | |||
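For illustration, a sketch of the fragment that add_recompute_ops contributes to the dumped kernel description; the tensor name below is hypothetical:

```python
# Hypothetical fragment of the dumped desc: 'recompute_ops' lists the output
# tensor names of the ops that the backend should recompute.
desc_fragment = {"recompute_ops": ["output_0_1"]}
```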
| def _pre_dump(self, outputs): | |||
| """restore name to before load""" | |||
| inplace_assign = {} # y_name, output_name | |||
| inplace_assign_z = None | |||
| for op in self.desc['op_desc']: | |||
| @@ -217,6 +219,14 @@ class CompositeGraph: | |||
| for t in outputs: | |||
| if t.name not in inplace_assign: | |||
| inplace_assign_z = t | |||
| return inplace_assign, inplace_assign_z | |||
| def dump(self, subgraph): | |||
| """Dump Graph to json""" | |||
| desc = {} | |||
| inputs, outputs = subgraph.deduce_parameters() | |||
| graph_ops = set(subgraph.ops) | |||
| inplace_assign, inplace_assign_z = self._pre_dump(outputs) | |||
| for key in self.desc: | |||
| if key == 'input_desc': | |||
| desc[key] = [ | |||
| @@ -251,7 +261,7 @@ class CompositeGraph: | |||
| op_desc.append(inplace_desc) | |||
| else: | |||
| op = self.tensors[d['output_desc'][0]['tensor_name']].op | |||
| if op in graph_ops: | |||
| if op in graph_ops or op in subgraph.recompute_ops: | |||
| op_desc.append(d) | |||
| desc[key] = op_desc | |||
| elif key == 'op': | |||
| @@ -260,6 +270,7 @@ class CompositeGraph: | |||
| desc[key] = self.desc[key] | |||
| desc = self.add_stitch_info(subgraph, desc) | |||
| desc = self.add_recompute_ops(subgraph, desc) | |||
| return desc | |||
| @@ -433,68 +433,5 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js | |||
| auto kernel_json = nlohmann::json::parse(kernel_json_str); | |||
| return DecodeFusedNodes(kernel_json); | |||
| } | |||
| StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) const { | |||
| StitchInfo info; | |||
| if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) { | |||
| nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch]; | |||
| if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) { | |||
| std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp]; | |||
| info.stitch_ops = stitch_ops; | |||
| } | |||
| if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) { | |||
| std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp]; | |||
| info.stitch_atomic_ops = stitch_atomic_ops; | |||
| } | |||
| } | |||
| return info; | |||
| } | |||
| void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, | |||
| const CNodePtr &node) const { | |||
| std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc]; | |||
| if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return; | |||
| std::string tensor_name = output_descs[0][kJsonKeyTensorName]; | |||
| if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) { | |||
| AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node); | |||
| MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope(); | |||
| } | |||
| if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) != | |||
| info.stitch_atomic_ops.end()) { | |||
| AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node); | |||
| MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope(); | |||
| } | |||
| } | |||
| bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | |||
| const std::map<std::string, AnfNodePtr> &address_node_map, | |||
| AnfNodePtrList *res_graphs) { | |||
| MS_EXCEPTION_IF_NULL(res_graphs); | |||
| MS_LOG(DEBUG) << "start decode, " << kernel_json; | |||
| // decode cnodes in graph. | |||
| std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc]; | |||
| if (op_node_descs.empty()) { | |||
| MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json; | |||
| return false; | |||
| } | |||
| StitchInfo info = GetStitchInfo(kernel_json); | |||
| for (const auto &op_desc : op_node_descs) { | |||
| if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) { | |||
| MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc; | |||
| return false; | |||
| } | |||
| std::string ptr_address = op_desc[kJsonKeyPtrAddress]; | |||
| if (address_node_map.count(ptr_address) == 0) { | |||
| MS_LOG(ERROR) << "Decode failed, ptr_address not found in map."; | |||
| return false; | |||
| } | |||
| auto node = address_node_map.at(ptr_address)->cast<CNodePtr>(); | |||
| SetStitchAttr(op_desc, info, node); | |||
| res_graphs->push_back(node); | |||
| } | |||
| MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size(); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -26,10 +26,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| struct StitchInfo { | |||
| std::vector<std::string> stitch_ops; | |||
| std::vector<std::string> stitch_atomic_ops; | |||
| }; | |||
| class AkgKernelJsonDecoder { | |||
| public: | |||
| AkgKernelJsonDecoder() { nodes_map_.clear(); } | |||
| @@ -37,15 +33,11 @@ class AkgKernelJsonDecoder { | |||
| FuncGraphPtr DecodeFusedNodes(const nlohmann::json &kernel_json); | |||
| FuncGraphPtr DecodeFusedNodes(const std::string &kernel_json_str); | |||
| bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map, | |||
| AnfNodePtrList *res_graphs); | |||
| private: | |||
| ParameterPtr DecodeParameter(const nlohmann::json ¶meter_json, const FuncGraphPtr &func_graph); | |||
| CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor); | |||
| AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph); | |||
| StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const; | |||
| void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const; | |||
| std::map<std::string, AnfNodePtr> nodes_map_; | |||
| }; | |||
| } // namespace kernel | |||
| @@ -54,6 +54,7 @@ constexpr auto kJsonKeyFusionType = "fusion_type"; | |||
| constexpr auto kJsonKeySubGraph = "sub_graph"; | |||
| constexpr auto kJsonKeyCoreNum = "core_num"; | |||
| constexpr auto kJsonKeyTypeInfo = "type_info"; | |||
| constexpr auto kJsonKeyRecomputeOps = "recompute_ops"; | |||
| constexpr auto kJsonKeyBufferStitch = "buffer_stitch"; | |||
| constexpr auto kJsonKeyStitchOp = "stitch_op"; | |||
| constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op"; | |||
| @@ -117,7 +117,7 @@ PassManagerPtr GraphKernelOptimizer::Split() const { | |||
| // which can avoid unnecessary input-output and get better performance. | |||
| // preprocess for ShapeOpsSplitter | |||
| pm->AddPass(std::make_shared<ExtendOutputForUpdateState>(), OptLevel_1); | |||
| std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; | |||
| std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape}; | |||
| pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops), OptLevel_1); | |||
| // Split kernel according to costmodel | |||
| @@ -32,6 +32,150 @@ | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) { | |||
| StitchInfo info; | |||
| if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) { | |||
| nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch]; | |||
| if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) { | |||
| std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp]; | |||
| info.stitch_ops = stitch_ops; | |||
| } | |||
| if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) { | |||
| std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp]; | |||
| info.stitch_atomic_ops = stitch_atomic_ops; | |||
| } | |||
| } | |||
| return info; | |||
| } | |||
| std::set<std::string> GetRecomputeOps(const nlohmann::json &kernel_json) { | |||
| if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) { | |||
| std::vector<std::string> recompute_ops = kernel_json[kJsonKeyRecomputeOps]; | |||
| return std::set<std::string>(recompute_ops.begin(), recompute_ops.end()); | |||
| } | |||
| return std::set<std::string>(); | |||
| } | |||
| bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set<std::string> &recompute_ops) { | |||
| std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc]; | |||
| if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) { | |||
| return false; | |||
| } | |||
| std::string tensor_name = output_descs[0][kJsonKeyTensorName]; | |||
| if (recompute_ops.count(tensor_name)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfNodePtr> *node_map) { | |||
| auto func_graph = orig_node->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto cnode = orig_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info())); | |||
| auto orig_inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> inputs; | |||
| for (auto inp : orig_inputs) { | |||
| if (node_map->find(inp) == node_map->end()) { | |||
| inputs.push_back(inp); | |||
| continue; | |||
| } | |||
| inputs.push_back((*node_map)[inp]); | |||
| } | |||
| CNodePtr cp_node = func_graph->NewCNode(inputs); | |||
| func_graph->AddNode(cp_node); | |||
| cp_node->set_abstract(cnode->abstract()); | |||
| cp_node->set_forward(cnode->forward().first, cnode->forward().second); | |||
| cp_node->set_inputs_value(cnode->inputs_value()); | |||
| ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope; | |||
| cp_node->set_scope(scope); | |||
| cp_node->set_kernel_info(cnode->kernel_info_ptr()); | |||
| (*node_map)[orig_node] = cp_node; | |||
| return cp_node->cast<CNodePtr>(); | |||
| } | |||
| void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) { | |||
| std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc]; | |||
| if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return; | |||
| std::string tensor_name = output_descs[0][kJsonKeyTensorName]; | |||
| if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) { | |||
| AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node); | |||
| MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope(); | |||
| } | |||
| if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) != | |||
| info.stitch_atomic_ops.end()) { | |||
| AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node); | |||
| MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope(); | |||
| } | |||
| } | |||
| // replace the original region root op with its copy within res_graphs | |||
| void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root, | |||
| const AnfNodePtr &cp_region_root) { | |||
| for (auto &node : *res_graphs) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto inputs = cnode->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (inputs[i] != orig_region_root) continue; | |||
| cnode->set_input(i, cp_region_root); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| bool SplitNodesDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, | |||
| const std::map<std::string, AnfNodePtr> &address_node_map, | |||
| AnfNodePtrList *res_graphs) { | |||
| MS_EXCEPTION_IF_NULL(res_graphs); | |||
| MS_LOG(DEBUG) << "start decode, " << kernel_json; | |||
| // decode cnodes in graph. | |||
| std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc]; | |||
| if (op_node_descs.empty()) { | |||
| MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json; | |||
| return false; | |||
| } | |||
| StitchInfo info = GetStitchInfo(kernel_json); | |||
| auto recompute_ops = GetRecomputeOps(kernel_json); | |||
| // key: original node, value: copied node | |||
| std::map<AnfNodePtr, AnfNodePtr> node_map; | |||
| // nodes that will be copied | |||
| AnfNodePtrList orig_region_nodes; | |||
| // nodes that will not be copied | |||
| AnfNodePtrList no_cp_nodes; | |||
| for (const auto &op_desc : op_node_descs) { | |||
| if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) { | |||
| MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc; | |||
| return false; | |||
| } | |||
| std::string ptr_address = op_desc[kJsonKeyPtrAddress]; | |||
| if (address_node_map.count(ptr_address) == 0) { | |||
| MS_LOG(ERROR) << "Decode failed, ptr_address not found in map."; | |||
| return false; | |||
| } | |||
| auto node = address_node_map.at(ptr_address)->cast<CNodePtr>(); | |||
| if (IsRecomputeOp(op_desc, recompute_ops)) { | |||
| auto cp_node = NewRecomputeNode(node, &node_map); | |||
| orig_region_nodes.push_back(node); | |||
| SetStitchAttr(op_desc, info, cp_node); | |||
| res_graphs->push_back(cp_node); | |||
| continue; | |||
| } | |||
| SetStitchAttr(op_desc, info, node); | |||
| res_graphs->push_back(node); | |||
| no_cp_nodes.push_back(node); | |||
| } | |||
| for (auto orig_node : orig_region_nodes) { | |||
| ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]); | |||
| } | |||
| MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size(); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| namespace opt { | |||
| namespace { | |||
| void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) { | |||
| @@ -620,7 +764,7 @@ class CostModelSplitSchemer : public SplitSchemer { | |||
| split_plan_.clear(); | |||
| for (const auto &graph_desc : graph_descs) { | |||
| AnfNodePtrList res_graph; | |||
| if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) { | |||
| if (!kernel::SplitNodesDecoder::DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) { | |||
| MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc; | |||
| return false; | |||
| } | |||
| @@ -731,6 +875,7 @@ class CostModelSplitSchemer : public SplitSchemer { | |||
| nlohmann::json flag_json; | |||
| flag_json["dump_as_text"] = flags.dump_as_text; | |||
| flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion; | |||
| flag_json["enable_recompute_fusion"] = flags.enable_recompute_fusion; | |||
| return flag_json.dump(); | |||
| } | |||
| @@ -15,11 +15,31 @@ | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <nlohmann/json.hpp> | |||
| #include "ir/func_graph.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| struct StitchInfo { | |||
| std::vector<std::string> stitch_ops; | |||
| std::vector<std::string> stitch_atomic_ops; | |||
| }; | |||
| class SplitNodesDecoder { | |||
| public: | |||
| SplitNodesDecoder() {} | |||
| ~SplitNodesDecoder() = default; | |||
| static bool DecodeSplitNodes(const nlohmann::json &kernel_json, | |||
| const std::map<std::string, AnfNodePtr> &address_node_map, AnfNodePtrList *res_graphs); | |||
| }; | |||
| } // namespace kernel | |||
| namespace opt { | |||
| class GraphKernelSplitter : public Pass { | |||
| public: | |||
| @@ -181,6 +181,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma | |||
| // Boolean flags | |||
| reg.AddFlag("dump_as_text", &dump_as_text); | |||
| reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, opt_level == OptLevel_3); | |||
| reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level == OptLevel_2); | |||
| reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3); | |||
| // Integer flags | |||
| @@ -203,6 +204,7 @@ std::string GraphKernelFlags::DumpAllFlags() const { | |||
| json["dump_as_text"] = dump_as_text; | |||
| json["enable_stitch_fusion"] = enable_stitch_fusion; | |||
| json["enable_recompute_fusion"] = enable_recompute_fusion; | |||
| json["enable_parallel_fusion"] = enable_parallel_fusion; | |||
| json["opt_level"] = opt_level; | |||
| @@ -67,6 +67,11 @@ class GraphKernelFlags { | |||
| */ | |||
| bool enable_stitch_fusion; | |||
| /** | |||
| * Enable recompute fusion in graph kernel fusion strategy. | |||
| */ | |||
| bool enable_recompute_fusion{true}; | |||
| /** | |||
| * Enable parallel fusion in graph kernel fusion strategy. | |||
| * | |||
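For reference, a flag registered this way is normally toggled from Python through the graph_kernel_flags context option. A minimal sketch, assuming the option accepts the usual "--key=value" string; per RegisterFlags above, the flag's default is derived from opt_level:

```python
# A minimal sketch, assuming graph_kernel_flags accepts "--key=value" pairs.
import mindspore.context as context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
                    enable_graph_kernel=True,
                    graph_kernel_flags="--enable_recompute_fusion=true")
```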
| @@ -0,0 +1,223 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.nn import Cell | |||
| import mindspore.ops.operations as P | |||
| # {cast} would be recomputed and fused | |||
| class Net1(Cell): | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.cast = P.Cast() | |||
| self.sum = P.ReduceSum(keep_dims=False) | |||
| def construct(self, x): | |||
| cast_res = self.cast(x, mstype.float32) | |||
| sum1_res = self.sum(cast_res, (0,)) | |||
| sum2_res = self.sum(cast_res, (1,)) | |||
| return sum1_res, sum2_res | |||
| # {sqrt} would be recomputed on Ascend | |||
| class Net2(Cell): | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.sqrt = P.Sqrt() | |||
| self.sum = P.ReduceSum(keep_dims=True) | |||
| self.add = P.Add() | |||
| self.neg = P.Neg() | |||
| def construct(self, x0, x1): | |||
| sqrt_res = self.sqrt(x0) | |||
| neg_res = self.neg(sqrt_res) | |||
| add_res = self.add(x1, sqrt_res) | |||
| sum_res = self.sum(add_res, (0,)) | |||
| return neg_res, sum_res | |||
| # {sqrt} would be recomputed | |||
| class Net3(Cell): | |||
| def __init__(self): | |||
| super(Net3, self).__init__() | |||
| self.sqrt = P.Sqrt() | |||
| self.add = P.Add() | |||
| self.neg = P.Neg() | |||
| def construct(self, x0, x1): | |||
| sqrt_res = self.sqrt(x0) | |||
| neg_res = self.neg(sqrt_res) | |||
| add_res = self.add(x1, sqrt_res) | |||
| return neg_res, add_res | |||
| # {sqrt, neg} would be recomputed | |||
| class Net4(Cell): | |||
| def __init__(self): | |||
| super(Net4, self).__init__() | |||
| self.sqrt = P.Sqrt() | |||
| self.neg = P.Neg() | |||
| self.sum = P.ReduceSum(keep_dims=False) | |||
| def construct(self, x): | |||
| sqrt_res = self.sqrt(x) | |||
| neg_res = self.neg(sqrt_res) | |||
| sum1_res = self.sum(neg_res, (0,)) | |||
| sum2_res = self.sum(neg_res, (1,)) | |||
| return sum1_res, sum2_res | |||
| # {sqrt} would be recomputed | |||
| class Net5(Cell): | |||
| def __init__(self): | |||
| super(Net5, self).__init__() | |||
| self.sqrt = P.Sqrt() | |||
| self.add = P.Add() | |||
| def construct(self, x0, x1, x2): | |||
| sqrt_res = self.sqrt(x0) | |||
| add1_res = self.add(sqrt_res, x1) | |||
| add2_res = self.add(sqrt_res, x2) | |||
| return add1_res, add2_res | |||
| def test_basic1(net): | |||
| def get_output(i0, net, enable_graph_kernel=False): | |||
| context.set_context(enable_graph_kernel=enable_graph_kernel) | |||
| net_obj = net() | |||
| output = net_obj(i0) | |||
| return output | |||
| i0 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16)) | |||
| expect = get_output(i0, net, False) | |||
| output = get_output(i0, net, True) | |||
| expect0_np = expect[0].asnumpy().copy() | |||
| output0_np = output[0].asnumpy().copy() | |||
| expect1_np = expect[1].asnumpy().copy() | |||
| output1_np = output[1].asnumpy().copy() | |||
| assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3) | |||
| assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3) | |||
| def test_basic2(net): | |||
| def get_output(i0, i1, net, enable_graph_kernel=False): | |||
| context.set_context(enable_graph_kernel=enable_graph_kernel) | |||
| net_obj = net() | |||
| output = net_obj(i0, i1) | |||
| return output | |||
| i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float32)) | |||
| i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float32)) | |||
| expect = get_output(i0, i1, net, False) | |||
| output = get_output(i0, i1, net, True) | |||
| expect0_np = expect[0].asnumpy().copy() | |||
| output0_np = output[0].asnumpy().copy() | |||
| expect1_np = expect[1].asnumpy().copy() | |||
| output1_np = output[1].asnumpy().copy() | |||
| assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3) | |||
| assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3) | |||
| def test_basic3(net): | |||
| def get_output(i0, i1, i2, net, enable_graph_kernel=False): | |||
| context.set_context(enable_graph_kernel=enable_graph_kernel) | |||
| net_obj = net() | |||
| output = net_obj(i0, i1, i2) | |||
| return output | |||
| i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float16)) | |||
| i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16)) | |||
| i2 = Tensor(np.random.uniform(1, 2, [2048, 1024]).astype(np.float16)) | |||
| expect = get_output(i0, i1, i2, net, False) | |||
| output = get_output(i0, i1, i2, net, True) | |||
| expect0_np = expect[0].asnumpy().copy() | |||
| output0_np = output[0].asnumpy().copy() | |||
| expect1_np = expect[1].asnumpy().copy() | |||
| output1_np = output[1].asnumpy().copy() | |||
| assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3) | |||
| assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gpu_1(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_basic1(Net1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gpu_2(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_basic2(Net2) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gpu_3(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_basic2(Net3) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gpu_4(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_basic1(Net4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gpu_5(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_basic3(Net5) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_ascend_1(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| test_basic1(Net1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_ascend_2(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| test_basic2(Net2) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_ascend_3(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| test_basic2(Net3) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_ascend_4(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| test_basic1(Net4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_ascend_5(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| test_basic3(Net5) | |||