#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""custom tiling function"""
from enum import Enum, unique
from functools import wraps

# ``numpy.double`` is the same object as ``numpy.core.double``; importing it
# from the top-level package avoids the ``numpy.core`` namespace, which is
# deprecated in NumPy 2.x.
from numpy import double

import akg
from akg import dim
from akg.utils.validation_check import check_input_type

# Registries filled by the ``reg_*`` helpers at the bottom of this module,
# keyed by the ``__name__`` of the registered function.
set_dim_func_map = {}
gen_key_func_map = {}

NODE_TYPE = "CustomTilingNode"
# Placeholder values for node fields that the caller did not set.
DEFAULT_VALUE = -1
DEFAULT_STRING = ""
BLOCK_SIZE = 32
CUBE_UNIT = 16


class TileTemplate(Enum):
    """class TileTemplate."""
    NC1HWC0 = "NC1HWC0"
    NCHW = "NCHW"
    # Same value as NCHW, so Enum makes this an alias of ``TileTemplate.NCHW``
    # (which is why this enum cannot be decorated with ``@unique``).
    DEFAULT_FORMAT = "NCHW"
    NHWC = "NHWC"


@unique
class TileLevel(Enum):
    """class TileLevel."""
    C1 = "C1"
    C0 = "C0"


@unique
class TileMode(Enum):
    """class TileMode."""
    AXIS = "AXIS"
    TENSOR = "TENSOR"
    COMMON = "COMMON"


@unique
class TileConstraint(Enum):
    """class TileConstraint."""
    MIN = "MIN"
    MOD = "MOD"
    MAX = "MAX"
    FACTOR = "FACTOR"
    CANDIDATE = "CANDIDATE"
    FORBID_ISOLATE = "FORBID_ISOLATE"
    SET_PRIORITY = "SET_PRIORITY"
    SET_EXPANSION = "SET_EXPANSION"
    SET_MEM_RATIO = "SET_MEM_RATIO"
    SET_AXIS_INFO = "SET_AXIS_INFO"
    THREAD_MIN = "THREAD_MIN"
    THREAD_MAX = "THREAD_MAX"
    THREAD_MOD = "THREAD_MOD"
    BLOCK_MIN = "BLOCK_MIN"
    BLOCK_MAX = "BLOCK_MAX"
    BLOCK_MOD = "BLOCK_MOD"


# Constraint -> keyword argument of ``create_custom_tiling_node`` that carries
# the constraint value, one table per public api.
_COMMON_CONSTRAINT_KWARGS = {
    TileConstraint.THREAD_MIN: "thread_min",
    TileConstraint.THREAD_MAX: "thread_max",
    TileConstraint.THREAD_MOD: "thread_mod",
    TileConstraint.BLOCK_MIN: "block_min",
    TileConstraint.BLOCK_MAX: "block_max",
    TileConstraint.BLOCK_MOD: "block_mod",
}

_AXIS_CONSTRAINT_KWARGS = {
    TileConstraint.MIN: "tile_min",
    TileConstraint.MOD: "tile_mod",
    TileConstraint.FACTOR: "tile_factor",
    TileConstraint.CANDIDATE: "tile_candidate",
    TileConstraint.MAX: "tile_max",
    TileConstraint.FORBID_ISOLATE: "forbid_isolate",
    TileConstraint.SET_AXIS_INFO: "axis_info",
    TileConstraint.SET_PRIORITY: "priority",
}

_TENSOR_CONSTRAINT_KWARGS = {
    TileConstraint.MIN: "tile_min",
    TileConstraint.MOD: "tile_mod",
    TileConstraint.FACTOR: "tile_factor",
    TileConstraint.CANDIDATE: "tile_candidate",
    TileConstraint.MAX: "tile_max",
    TileConstraint.FORBID_ISOLATE: "forbid_isolate",
    TileConstraint.SET_PRIORITY: "priority",
    TileConstraint.SET_EXPANSION: "expansion",
}


def _constraint_kwarg(constraint, kwarg_table):
    """Return the node keyword for ``constraint`` or raise TypeError if unsupported."""
    kwarg = kwarg_table.get(constraint)
    if kwarg is None:
        raise TypeError("Constraint {} is not supported in this api, please use other api"
                        .format(constraint.value))
    return kwarg


@check_input_type((double, float, int, list), TileConstraint, TileLevel)
def modify_common_constraints(value, constraint, level=TileLevel.C1):
    """
    api for dsl to modify some default constraint used in auto tiling.

    Args:
        value: constraint value; converted with ``double`` for SET_MEM_RATIO,
            forwarded unchanged for the THREAD_*/BLOCK_* constraints.
        constraint (TileConstraint): SET_MEM_RATIO, THREAD_MIN/MAX/MOD or
            BLOCK_MIN/MAX/MOD.
        level (TileLevel): only forwarded to the node for SET_MEM_RATIO.

    Returns:
        The node created by ``create_custom_tiling_node`` in COMMON mode.

    Raises:
        ValueError: if ``constraint`` is not a TileConstraint member.
        TypeError: if ``constraint`` is a member this api does not handle.
    """
    # isinstance instead of ``in TileConstraint``: Enum containment raises
    # TypeError for non-members before Python 3.12 and matches raw values
    # from 3.12 on, neither of which is what this guard intends.
    if not isinstance(constraint, TileConstraint):
        raise ValueError("Tile constraints must be chosen from {0}".format(TileConstraint))
    if constraint == TileConstraint.SET_MEM_RATIO:
        return create_custom_tiling_node(TileMode.COMMON, tile_level=level, mem_ratio=double(value))
    kwarg = _constraint_kwarg(constraint, _COMMON_CONSTRAINT_KWARGS)
    return create_custom_tiling_node(TileMode.COMMON, **{kwarg: value})


@check_input_type((str, int), TileConstraint, int, (int, list, tuple, type(None)), TileLevel)
def create_constraint_on_axis(values, constraints, band=0, axis=None, level=TileLevel.C1):
    """
    api for dsl to create tiling constraints on certain axis.

    Args:
        values (str or int): the constraint value.
        constraints (TileConstraint): MIN, MOD, FACTOR, CANDIDATE, MAX,
            FORBID_ISOLATE, SET_AXIS_INFO or SET_PRIORITY.
        band (int): forwarded as ``tile_band``.
        axis (int, list, tuple or None): forwarded as ``tile_axis``; when None
            it defaults to ``[0]`` (one axis for the single value).
        level (TileLevel): forwarded as ``tile_level``.

    Returns:
        list of nodes created by ``create_custom_tiling_node`` in AXIS mode.

    Raises:
        ValueError: if ``constraints`` is not a TileConstraint member or the
            number of axes differs from the number of values.
        TypeError: if ``values``/``axis`` have an unsupported type or the
            constraint is not handled by this api.
    """
    if not isinstance(constraints, TileConstraint):
        raise ValueError("Tile constraints must be chosen from {0}".format(TileConstraint))

    # Normalise ``values`` BEFORE deriving the default axis list. Doing it the
    # other way round called len() on the raw int (TypeError) or on the raw
    # string (one axis per character, then a length mismatch below).
    if isinstance(values, (str, int)):
        values = [values]
    else:
        raise TypeError("Tiling factor must be string or int, while receives {}".format(type(values)))

    if axis is None:
        axis = list(range(len(values)))
    elif not isinstance(axis, (int, list, tuple)):
        raise TypeError("Axis should be int, list or tuple")
    if isinstance(axis, int):
        axis = [axis]

    if len(axis) != len(values):
        raise ValueError("Length of axis must equal to length of values")

    res = []
    for axis_idx, val in zip(axis, values):
        kwarg = _constraint_kwarg(constraints, _AXIS_CONSTRAINT_KWARGS)
        res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level, tile_band=band,
                                             tile_axis=axis_idx, **{kwarg: val}))
    return res


@check_input_type((akg.tvm.tensor.Tensor, list, tuple), (str, int, list, tuple), TileConstraint,
                  (int, list, tuple, type(None)), TileLevel)
def create_constraint_on_tensor(tensor, values, constraints, tensor_pos=None, level=TileLevel.C1):
    """
    api for dsl to create tiling constraints on certain tensor.

    Args:
        tensor: a tvm Tensor or a list/tuple of them; only ``op.name`` is used.
        values (str, int, list or tuple): constraint value(s).
        constraints (TileConstraint): MIN, MOD, FACTOR, CANDIDATE, MAX,
            FORBID_ISOLATE, SET_PRIORITY or SET_EXPANSION.
        tensor_pos (int, list, tuple or None): forwarded as ``tile_pos``, one
            per value; defaults to ``0 .. len(values) - 1``.
        level (TileLevel): forwarded as ``tile_level``.

    Returns:
        list of nodes created by ``create_custom_tiling_node`` in TENSOR mode,
        one per (tensor, position/value) pair.

    Raises:
        ValueError: if ``constraints`` is not a TileConstraint member or the
            positions and values differ in length.
        TypeError: if a list element is not a Tensor or the constraint is not
            handled by this api.
    """
    if not isinstance(constraints, TileConstraint):
        raise ValueError("Tile constraint must be chosen from {0}".format(TileConstraint))

    if isinstance(tensor, (list, tuple)):
        for t in tensor:
            if not isinstance(t, akg.tvm.tensor.Tensor):
                raise TypeError("Tensor should be tvm.tensor.Tensor or a list/tuple of tvm.tensor.Tensor.")
    if isinstance(tensor, akg.tvm.tensor.Tensor):
        tensor_name = [tensor.op.name]
    else:
        tensor_name = [t.op.name for t in tensor]

    if isinstance(values, (str, int)):
        values = [values]
    if tensor_pos is None:
        tensor_pos = list(range(len(values)))
    elif isinstance(tensor_pos, int):
        tensor_pos = [tensor_pos]
    if len(tensor_pos) != len(values):
        raise ValueError("Length of tensor position is not compatible with length of constraint values")

    strategy = []
    for name in tensor_name:
        for pos, val in zip(tensor_pos, values):
            # Looked up inside the loop so an unsupported constraint only
            # raises when there is at least one node to build.
            kwarg = _constraint_kwarg(constraints, _TENSOR_CONSTRAINT_KWARGS)
            node_kwargs = {"tile_level": level, "tensor_name": name, kwarg: val}
            # SET_EXPANSION nodes are created without a tile position.
            if constraints != TileConstraint.SET_EXPANSION:
                node_kwargs["tile_pos"] = pos
            strategy.append(create_custom_tiling_node(TileMode.TENSOR, **node_kwargs))
    return strategy


@check_input_type(akg.tvm.tensor.Tensor, TileTemplate, TileLevel)
def create_template(tensor, template, level=TileLevel.C1):
    """
    create template according to given template arg.

    Args:
        tensor: tvm Tensor; only ``op.name`` is used.
        template (TileTemplate): which ``template_*`` function to apply.
        level (TileLevel): forwarded to the template function.

    Returns:
        list of nodes produced by the matching template function, or an empty
        list when no template function matches.

    Raises:
        ValueError: if ``template`` is not a TileTemplate member.
    """
    tensor_name = tensor.op.name
    if not isinstance(template, TileTemplate):
        raise ValueError("Invalid template name {0}, must chosen from {1}".
                         format(template, TileTemplate))
    if template in (TileTemplate.NCHW, TileTemplate.DEFAULT_FORMAT):
        return template_nchw(tensor_name, level)
    if template == TileTemplate.NC1HWC0:
        return template_nc1hwc0(tensor_name, level)
    if template == TileTemplate.NHWC:
        return template_nhwc(tensor_name, level)
    return []


def to_tvm_type(value, t_type):
    """
    transform integer and string to corresponding type in tvm.

    Args:
        value: int, str, or an already converted IntImm/StringImm.
        t_type (str): label used only in the error message.

    Returns:
        ``IntImm("int32", value)`` for int, ``StringImm(value)`` for str, or
        ``value`` itself when it is already an IntImm/StringImm.

    Raises:
        TypeError: for any other type.
    """
    if isinstance(value, int):
        return akg.tvm.expr.IntImm("int32", value)
    if isinstance(value, str):
        return akg.tvm.expr.StringImm(value)
    if isinstance(value, (akg.tvm.expr.IntImm, akg.tvm.expr.StringImm)):
        return value
    raise TypeError("{} only support integer or string, found {}".format(t_type, type(value)))


def create_custom_tiling_node(tile_mode,
                              tile_level=TileLevel.C1,
                              tensor_name=DEFAULT_STRING,
                              tile_pos=DEFAULT_VALUE,
                              tile_band=DEFAULT_VALUE,
                              tile_axis=DEFAULT_VALUE,
                              tile_min=DEFAULT_VALUE,
                              tile_max=DEFAULT_VALUE,
                              tile_mod=DEFAULT_VALUE,
                              tile_factor=DEFAULT_VALUE,
                              tile_candidate=DEFAULT_VALUE,
                              forbid_isolate=DEFAULT_VALUE,
                              axis_info=DEFAULT_STRING,
                              priority=DEFAULT_VALUE,
                              expansion=DEFAULT_VALUE,
                              mem_ratio=double(DEFAULT_VALUE),
                              thread_min=None,
                              thread_max=None,
                              thread_mod=None,
                              block_min=None,
                              block_max=None,
                              block_mod=None):
    """
    default method to create custom tiling node, all values are default except tile mode.

    ``tile_min``, ``tile_max``, ``tile_mod``, ``tile_factor`` and
    ``tile_candidate`` go through ``to_tvm_type``; ``tile_level``,
    ``tile_mode``, ``tensor_name`` and ``axis_info`` are wrapped in
    ``StringImm``; every other argument is passed to the node unchanged.
    The thread_*/block_* arguments default to None, which is replaced by a
    fresh empty list per call (avoids sharing one mutable default list).

    Returns:
        The result of ``akg.tvm.make.node(NODE_TYPE, ...)``.
    """
    thread_min = [] if thread_min is None else thread_min
    thread_max = [] if thread_max is None else thread_max
    thread_mod = [] if thread_mod is None else thread_mod
    block_min = [] if block_min is None else block_min
    block_max = [] if block_max is None else block_max
    block_mod = [] if block_mod is None else block_mod

    tile_min = to_tvm_type(tile_min, "tile_min")
    tile_max = to_tvm_type(tile_max, "tile_max")
    tile_mod = to_tvm_type(tile_mod, "tile_mod")
    tile_factor = to_tvm_type(tile_factor, "tile_factor")
    tile_candidate = to_tvm_type(tile_candidate, "tile_candidate")
    return akg.tvm.make.node(NODE_TYPE,
                             tile_level=akg.tvm.expr.StringImm(tile_level.value),
                             tile_mode=akg.tvm.expr.StringImm(tile_mode.value),
                             tensor_name=akg.tvm.expr.StringImm(tensor_name),
                             tile_pos=tile_pos,
                             tile_band=tile_band,
                             tile_axis=tile_axis,
                             tile_min=tile_min,
                             tile_max=tile_max,
                             tile_mod=tile_mod,
                             tile_factor=tile_factor,
                             tile_candidate=tile_candidate,
                             forbid_isolate=forbid_isolate,
                             axis_info=akg.tvm.expr.StringImm(axis_info),
                             priority=priority,
                             expansion=expansion,
                             mem_ratio=mem_ratio,
                             thread_min=thread_min,
                             thread_max=thread_max,
                             thread_mod=thread_mod,
                             block_min=block_min,
                             block_max=block_max,
                             block_mod=block_mod)


def template_nc1hwc0(tensor_name, level):
    """
    create default tiling strategy for nc1hwc0 template.

    Returns two TENSOR-mode nodes: position 0 with ``tile_factor`` 1 and
    position 4 with ``tile_max`` "FULL".
    """
    node_n = create_custom_tiling_node(TileMode.TENSOR,
                                       tile_level=level,
                                       tensor_name=tensor_name,
                                       tile_pos=0,
                                       tile_factor=to_tvm_type(1, "tile_factor"))
    node_c0 = create_custom_tiling_node(TileMode.TENSOR,
                                        tile_level=level,
                                        tensor_name=tensor_name,
                                        tile_pos=4,
                                        tile_max="FULL")
    return [node_n, node_c0]


def template_nchw(tensor_name, level):
    """
    create default tiling strategy for nchw template.

    Returns two TENSOR-mode nodes: position 0 with ``tile_factor`` 1 and
    position 1 with ``tile_mod`` CUBE_UNIT.
    """
    node_n = create_custom_tiling_node(TileMode.TENSOR,
                                       tile_level=level,
                                       tensor_name=tensor_name,
                                       tile_pos=0,
                                       tile_factor=to_tvm_type(1, "tile_factor"))
    node_c = create_custom_tiling_node(TileMode.TENSOR,
                                       tile_level=level,
                                       tensor_name=tensor_name,
                                       tile_pos=1,
                                       tile_mod=to_tvm_type(CUBE_UNIT, "tile_mod"))
    return [node_n, node_c]


def template_nhwc(tensor_name, level):
    """
    create default tiling strategy for nhwc template.

    Returns two TENSOR-mode nodes: position 0 with ``tile_factor`` 1 and
    position 3 with ``tile_mod`` CUBE_UNIT.
    """
    node_n = create_custom_tiling_node(TileMode.TENSOR,
                                       tile_level=level,
                                       tensor_name=tensor_name,
                                       tile_pos=0,
                                       tile_factor=to_tvm_type(1, "tile_factor"))
    node_c = create_custom_tiling_node(TileMode.TENSOR,
                                       tile_level=level,
                                       tensor_name=tensor_name,
                                       tile_pos=3,
                                       tile_mod=to_tvm_type(CUBE_UNIT, "tile_mod"))
    return [node_n, node_c]


def set_dims(tiling):
    """
    Set dim for tiling.

    Args:
        tiling: iterable whose elements are either length-2
            ``(c1_tile, c0_tile)`` or length-4
            ``(band_index, axis_index, c1_tile, c0_tile)``. For length-2
            elements the band index is 0 and the axis index is the element's
            position in ``tiling``.

    Returns:
        str: ``str()`` of the ``dim.Dim`` object after all ``setdim`` calls.

    Raises:
        RuntimeError: if an element has any other length.
    """
    info = dim.Dim()
    for position, tile_d in enumerate(tiling):
        if len(tile_d) == 2:  # only c1 and c0 tile
            index, axis = 0, position
            c1, c0 = tile_d
        elif len(tile_d) == 4:  # index, axis, c1, c0
            index, axis, c1, c0 = tile_d
        else:
            raise RuntimeError("Each element in tiling should be length-2 (c1_tile, c0_tile) "
                               "or length-4 (band_index, axis_index, c1_tile, c0_tile)")
        info.setdim(index=index, axis=axis, tilel1=c1, tilel0=c0)
    return str(info)


def set_dims_by_key(key, map_):
    """Set dim for tiling by key; returns "" when ``key`` is not in ``map_``."""
    if key in map_:
        return set_dims(map_[key])
    return ""


def reg_set_dim_func(set_dim_func):
    """
    register setdim function.

    Decorator factory. Note the registration into ``set_dim_func_map`` (keyed
    by the decorated function's ``__name__``) happens each time the wrapped
    function is called, not when the decorator is applied.
    """
    def decorate(func_):
        @wraps(func_)
        def wrapper(*args, **kwargs):
            set_dim_func_map[func_.__name__] = set_dim_func
            return func_(*args, **kwargs)
        return wrapper
    return decorate


def reg_set_dim_func_by_func(func_, set_dim_func):
    """register setdim function by function (immediately, keyed by ``func_.__name__``)."""
    set_dim_func_map[func_.__name__] = set_dim_func


def reg_gen_key_func(gen_key_func):
    """
    register generated key by function.

    Decorator factory. The registration into ``gen_key_func_map`` (keyed by
    the decorated function's ``__name__``) happens each time the wrapped
    function is called, not when the decorator is applied.
    """
    def decorate(func_):
        @wraps(func_)
        def wrapper(*args, **kwargs):
            gen_key_func_map[func_.__name__] = gen_key_func
            return func_(*args, **kwargs)
        return wrapper
    return decorate