|
|
|
@@ -104,12 +104,10 @@ def trans_data(inputs, attrs): |
|
|
|
raise ValueError("length of shape of input_data should be greater than or equal to 2, but got %d" |
|
|
|
% len(shape)) |
|
|
|
m, n = shape[-2:] |
|
|
|
output_shape = [] |
|
|
|
for i in range(0, len(shape) - 2): |
|
|
|
output_shape.append(shape[i]) |
|
|
|
batch_dims = shape[:-2] |
|
|
|
m1 = (m + cs - 1) // cs |
|
|
|
n1 = (n + cs - 1) // cs |
|
|
|
output_shape.extend([n1, m1, cs, cs]) |
|
|
|
output_shape = batch_dims + [n1, m1, cs, cs] |
|
|
|
|
|
|
|
def fcompute(*output_indices): |
|
|
|
input_indices = [] |
|
|
|
@@ -126,8 +124,26 @@ def trans_data(inputs, attrs): |
|
|
|
input_indices.append(n_indice) |
|
|
|
res = tvm.if_then_else(tvm.any(m_indice >= m, n_indice >= n), tvm.const(0, dtype), data(*input_indices)) |
|
|
|
return res |
|
|
|
output = tvm.compute(output_shape, fcompute, name=output_name) |
|
|
|
return output |
|
|
|
|
|
|
|
# If it is implemented with tvm.compute, |
|
|
|
# the generated stmt is difficult to process for poly in the fusion scene |
|
|
|
def kernel_ir(input_, output):
    """Build the IR statement that copies ``input_`` into the blocked ``output`` layout.

    Each element ``input_[batch..., row, col]`` is stored to
    ``output[batch..., col // cs, row // cs, row % cs, col % cs]``.
    ``batch_dims``, ``m``, ``n`` and ``cs`` are read from the enclosing scope.

    Args:
        input_: the first input passed in by ``tvm.extern``; read via ``ib.load``.
        output: the first output passed in by ``tvm.extern``; written via ``ib.store``.

    Returns:
        The statement returned by the IR builder's ``get()``.
    """
    builder = tvm.ir_builder.create()
    # NOTE(review): for_range_n presumably yields a list with one loop variable
    # per batch dimension, since it is concatenated with index lists below -- confirm.
    with builder.for_range_n(batch_dims, "bs") as batch_idx:
        with builder.for_range(0, m) as row:
            with builder.for_range(0, n) as col:
                # Split each coordinate into a block index and an offset inside the block.
                row_blk = row // cs
                row_off = row % cs
                col_blk = col // cs
                col_off = col % cs
                # The column block index comes first in the output layout.
                dst_index = batch_idx + [col_blk, row_blk, row_off, col_off]
                src_index = batch_idx + [row, col]
                element = builder.load(input_, src_index)
                builder.store(output, dst_index, element)
    return builder.get()
|
|
|
|
|
|
|
if (shape[-1] % cs == 0 and shape[-2] % cs == 0): |
|
|
|
return tvm.extern(output_shape, [data], lambda ins, outs: kernel_ir(ins[0], outs[0]), name=output_name, |
|
|
|
dtype=data.dtype) |
|
|
|
else: |
|
|
|
return tvm.compute(output_shape, fcompute, name=output_name) |
|
|
|
|
|
|
|
# FRACTAL_NZ: zN fractal format |
|
|
|
if (src_format == "DefaultFormat" or src_format == "NCHW") and dst_format == "FRACTAL_NZ": |
|
|
|
|