|
- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """operator dsl function: four2five"""
- import akg
- import akg.tvm
- from akg.tvm.hybrid import script
- from akg.topi.nn import pad as tvm_pad
- from akg.utils.format_transform import get_shape, get_bytes, to_tvm_const
- from akg.utils import validation_check as vc_util
- from akg.utils import custom_tiling as ct_util
- from akg.utils import dynamic_shape as ds
-
- C_LIMIT_FOR_CAST = 3600
-
-
- def get_attrs():
- """get attrs."""
- attrs = {
- "help_tiling": 0,
- "pragma_sink_last_axis": False,
- "enable_pre_poly_loop_partition": False
- }
- return attrs
-
-
- def get_dynamic_attrs():
- """get dynamic attrs."""
- attrs = {
- "help_tiling": 0,
- "pragma_sink_last_axis": False,
- "enable_pre_poly_loop_partition": True,
- "dynamic_shape_bound": 65535,
- "enable_post_poly_loop_partition": False,
- "enable_double_buffer:": False,
- # "enable_scalar_align": True,
- }
- return attrs
-
-
- four2five_set_dim_map = {
- "((1, 1, 7, 7), 'NCHW', 'float32', 'float16')": ((1, 1), (1, 1), (7, 1), (7, 1), (16, 1)),
- "((1, 7, 7), 'NCHW', 'float32', 'float16')": ((1, 1), (7, 1), (7, 1), (16, 1)),
- "((1, 1, I2, I3), 'NCHW', 'float32', 'float16')": ((1, 1), (1, 1), (1, 1), (129, 1), (2048, 1)),
- }
-
-
- def four2five_set_dim_func(data, format_, dst_type):
- """set dim info for attr."""
- shape = get_shape(data)
- if format_ == 'NCHW':
- n, _, h, w = shape
- else:
- n, h, w, _ = shape
-
- shape[0] = 1
- if h != 1 and w != 1:
- if format_ == 'NCHW' and shape[1] > 16:
- shape[1] = 1
- if format_ == 'NHWC' and shape[-1] > 16:
- shape[-1] = 1
-
- if n == 1:
- shape.remove(shape[0])
-
- hash_key = str((tuple(shape), format_, data.dtype, dst_type))
- return ct_util.set_dims_by_key(hash_key, four2five_set_dim_map), hash_key
-
-
- def four2five_tiling_strategy(tensor, input_format, expansion=None):
- """Custom tiling strategy for four2five op."""
- strategy = ct_util.create_template(tensor=tensor,
- template=ct_util.TileTemplate.NC1HWC0)
- if input_format == "NHWC" or expansion:
- priority_map = {4: 0, 1: 1, 3: 2, 2: 3, 0: 4} # tile in C0->C1->W->H->N sequence
- for pos, priority in priority_map.items():
- strategy.append(ct_util.create_constraint_on_tensor(tensor=tensor,
- values=priority,
- constraints=ct_util.TileConstraint.SET_PRIORITY,
- tensor_pos=pos)[0])
- if expansion:
- strategy.append(ct_util.create_constraint_on_tensor(tensor=tensor,
- values=expansion,
- constraints=ct_util.TileConstraint.SET_EXPANSION)[0])
- return strategy
-
- def four2five_tiling_strategy_dynamic(tensor, input_format):
- """Custom tiling strategy for four2five op."""
- strategy = list()
- if input_format == "NCHW":
- shape = get_shape(tensor)
- if shape[1] == 1:
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 0)[0])
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 1)[0])
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 2)[0])
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 112, ct_util.TileConstraint.FACTOR, 3)[0])
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 16, ct_util.TileConstraint.FACTOR, 4)[0])
- elif shape[1] == 128:
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 0)[0])
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 1)[0])
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 2)[0])
- strategy.append(ct_util.create_constraint_on_tensor(tensor, "FULL", ct_util.TileConstraint.MAX, 3)[0])
- strategy.append(ct_util.create_constraint_on_tensor(tensor, 16, ct_util.TileConstraint.FACTOR, 4)[0])
- return strategy
-
- @vc_util.check_input_type(akg.tvm.tensor.Tensor, str, str, bool)
- def four2five(data, format_, dst_dtype='float16', need_custom_tiling=True):
- """
- Convert 4-dims "data" to 5-dims,the format of "data" is defined in "format_"
-
- Args:
- data (tvm.tensor.Tensor): 4-dims tensor of type float16, float32
- format_ (str): a str defined the format of "data"
- dst_dtype (str): a str defined the type of output, could be float16 or float32
-
- Returns:
- 5-dims tvm.tensor.Tensor,type is defined by dst_dtype,
- which shape is [N, ceil(C / 16), H, W, 16] and attr about tiling args
-
- Raises:
- ValueError: If the type of format_ is invalid.
-
- """
- # Check dtype
- vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
- # Check shape
- shape = get_shape(data)
- vc_util.davinci_format_check(shape, format_, dim=4)
-
- # Check format
- if format_ not in ['NCHW', 'NHWC']:
- raise ValueError("{} format is not support, four2five only support NCHW and NHWC format input"
- .format(format_))
- last_channel = 16
- if format_ == "NCHW":
- bs, c, h, w = get_shape(data)
- else:
- bs, h, w, c = get_shape(data)
- pad_c = c
- if c % last_channel != 0:
- pad_c = (c + 15) // last_channel * last_channel
- c1 = pad_c // last_channel
- c0 = last_channel
- is_dynamic = ds.shape_is_dynamic(data)
- if not is_dynamic:
- attrs = get_attrs()
- else:
- attrs = get_dynamic_attrs()
- # Check size c when casting happens
- if data.dtype != dst_dtype and c0 * c1 >= C_LIMIT_FOR_CAST:
- raise ValueError("When input and output data type is not matched, shape of 'c' axis should not exceed {}, "
- "while currently set is {}".format(C_LIMIT_FOR_CAST, c0 * c1))
-
- @script(capture=locals())
- def nchw_to_nc1hwc0_step(inputs, bs, c1, h, w, c0):
- output = allocate((bs, c1, h, c0, w), inputs.dtype, "local")
- for n_i in range(bs):
- for c_i in range(c1):
- for h_i in range(h):
- for w_i in range(w):
- for c_i0 in range(c0):
- output[n_i, c_i, h_i, c_i0, w_i] = inputs[n_i, c_i * last_channel + c_i0, h_i, w_i]
- output1 = allocate((bs, c1, h, w, c0), inputs.dtype, "local")
- for n_i in range(bs):
- for c_i in range(c1):
- for h_i in range(h):
- for w_i in range(w):
- for c_i0 in range(c0):
- output1[n_i, c_i, h_i, w_i, c_i0] = output[n_i, c_i, h_i, c_i0, w_i]
- return output1
-
- @script(capture=locals())
- def nchw_to_nc1hwc0(inputs, bs, c1, h, w, c0):
- output = allocate((bs, c1, h, w, c0), inputs.dtype, "local")
- for n_i in range(bs):
- for c_i in range(c1):
- for h_i in range(h):
- for w_i in range(w):
- for c_i0 in range(c0):
- output[n_i, c_i, h_i, w_i, c_i0] = inputs[n_i, c_i * last_channel + c_i0, h_i, w_i]
- return output
-
- @script(capture=locals())
- def nhwc_to_nc1hwc0(inputs, zero, bs, c1, h, w, c0):
- output = allocate((bs, c1, h, w, c0), inputs.dtype, "local")
- for n_i in range(bs):
- for c_i in range(c1):
- for h_i in range(h):
- for w_i in range(w):
- for c_i0 in range(c0):
- if c_i * last_channel + c_i0 < c:
- output[n_i, c_i, h_i, w_i, c_i0] = inputs[n_i, h_i, w_i, c_i * last_channel + c_i0]
- else:
- output[n_i, c_i, h_i, w_i, c_i0] = zero
-
- return output
-
- cast_data = data
- need_cast = data.dtype == 'float32' and dst_dtype == 'float16'
- if c % last_channel != 0 or need_cast:
- expansion = int(ct_util.BLOCK_SIZE / get_bytes(data.dtype))
- else:
- expansion = None
- # float32 -> float16, need to cast before transform
- if need_cast:
- cast_data = akg.lang.cce.cast_to(data, dst_dtype)
-
- zero_ = akg.tvm.const(0.0, cast_data.dtype)
- if format_ == "NCHW":
- if c % last_channel != 0:
- pad_shape = [bs, pad_c, h, w]
- if h == 1 and w == 1:
- # if h and w both are 1, it is pad last dim case
- output_shape = [bs, pad_c // last_channel, h, w, last_channel]
-
- output = akg.tvm.compute(output_shape,
- lambda i, c1, k, l, c0: akg.tvm.expr.Select(
- c0 < c - c1 * last_channel, cast_data[i, c1 * last_channel + c0, k, l],
- akg.tvm.const(0, cast_data.dtype)),
- name="output")
- else:
- # if need to pad c dim, separate transpose to two steps
- # first is nchw -> nc1hc0w, second is nc1hc0w -> nc1hwc0
- pad_data = akg.tvm.compute(pad_shape,
- lambda i, j, k, l: akg.tvm.expr.Select(j < c, cast_data[i, j, k, l], zero_),
- name="pad_data")
- output = nchw_to_nc1hwc0_step(
- pad_data,
- to_tvm_const(bs),
- to_tvm_const(c1),
- to_tvm_const(h),
- to_tvm_const(w),
- to_tvm_const(c0))
-
- else:
- if not is_dynamic and data.dtype == "float16" and h * w % last_channel == 0 and h * w < 3600:
- output_shape = [bs, c1, h, w, c0]
- output = akg.tvm.compute(output_shape, lambda n, c1, h, w, c0:
- akg.lang.cce.four2five_nchw(cast_data[n, c1 * last_channel + c0, h, w]),
- name="output")
-
- else:
- output = nchw_to_nc1hwc0(
- cast_data,
- to_tvm_const(bs),
- to_tvm_const(c1),
- to_tvm_const(h),
- to_tvm_const(w),
- to_tvm_const(c0))
-
- else:
- if not is_dynamic and c < last_channel:
- rank = 5 # (n, c1, h, w, c0)
- pad_before = []
- pad_after = []
- for _ in range(rank):
- pad_before.append(0)
- pad_after.append(0)
- pad_after[-1] = last_channel - c
- # As c < last_channel, c1 is 1
- output = akg.tvm.compute((bs, c1, h, w, c), lambda bs_i, _, h_i, w_i, c_i: cast_data[
- bs_i, h_i, w_i, c_i], name="output")
- output = tvm_pad(output, pad_before, pad_after=pad_after, name='pad_output')
- else:
- output = nhwc_to_nc1hwc0(
- cast_data,
- zero_,
- to_tvm_const(bs),
- to_tvm_const(c1),
- to_tvm_const(h),
- to_tvm_const(w),
- to_tvm_const(c0))
-
- # float16 -> float32, need to cast after transform
- if data.dtype == 'float16' and dst_dtype == 'float32':
- output = akg.lang.cce.cast_to(output, dst_dtype)
-
- vc_util.davinci_format_check(output.shape, "NC1HWC0", dim=5)
-
- if not is_dynamic:
- dim_info, _ = four2five_set_dim_func(data, format_, dst_dtype)
- if dim_info != "":
- attrs["dim"] = dim_info
- if need_custom_tiling:
- attrs["custom_tiling"] = four2five_tiling_strategy(output, format_, expansion)
- elif need_custom_tiling:
- attrs["custom_tiling"] = four2five_tiling_strategy_dynamic(output, format_)
-
- if is_dynamic:
- attrs["enable_feature_library_pre_poly"] = True
- return output, attrs
|