Browse Source

Pre Merge pull request !88 from hanhuifeng/transdata_composite

pull/88/MERGE
hanhuifeng Gitee 5 years ago
parent
commit
67d3bcb5a4
2 changed files with 23 additions and 7 deletions
  1. +1
    -1
      python/akg/composite/build_module.py
  2. +22
    -6
      python/akg/composite/topi.py

+ 1
- 1
python/akg/composite/build_module.py View File

@@ -557,7 +557,7 @@ def _enable_auto_inline(kernel_info):
# For the TransData op operator, if the inline is not performed,
# the operator fusion scene is difficult to handle for the poly.
# So are MatMul/BatchMatMul with bias.
if op['name'] in ["TransData", "MatMul", "BatchMatMul"]:
if op['name'] in ["MatMul", "BatchMatMul"]:
return True
# For the Ascend, turn 'enable_auto_inline' off for composite op by default.
return False


+ 22
- 6
python/akg/composite/topi.py View File

@@ -104,12 +104,10 @@ def trans_data(inputs, attrs):
raise ValueError("length of shape of input_data should be greater than or equal to 2, but got %d"
% len(shape))
m, n = shape[-2:]
output_shape = []
for i in range(0, len(shape) - 2):
output_shape.append(shape[i])
batch_dims = shape[:-2]
m1 = (m + cs - 1) // cs
n1 = (n + cs - 1) // cs
output_shape.extend([n1, m1, cs, cs])
output_shape = batch_dims + [n1, m1, cs, cs]

def fcompute(*output_indices):
input_indices = []
@@ -126,8 +124,26 @@ def trans_data(inputs, attrs):
input_indices.append(n_indice)
res = tvm.if_then_else(tvm.any(m_indice >= m, n_indice >= n), tvm.const(0, dtype), data(*input_indices))
return res
output = tvm.compute(output_shape, fcompute, name=output_name)
return output

# When this op is implemented with tvm.compute, the generated stmt is
# difficult for poly to process in fusion scenarios.
def kernel_ir(input_, output):
    """Build the IR statement that scatters ``input_`` into ``output``.

    For every batch index and every (row, col) position in the last two
    axes of ``input_``, the element is stored to
    ``output[batch..., col // cs, row // cs, row % cs, col % cs]``.

    ``batch_dims``, ``m``, ``n`` and ``cs`` are taken from the enclosing
    scope.

    Args:
        input_: buffer read via ``load``.
        output: buffer written via ``store``.

    Returns:
        The statement produced by the IR builder's ``get()``.
    """
    builder = tvm.ir_builder.create()
    # NOTE(review): for_range_n presumably yields a list of loop variables,
    # one per batch dim, since it is concatenated with lists below -- confirm.
    with builder.for_range_n(batch_dims, "bs") as batch_idx:
        with builder.for_range(0, m) as row:
            with builder.for_range(0, n) as col:
                # Split each coordinate into (block index, offset in block).
                row_blk, row_off = row // cs, row % cs
                col_blk, col_off = col // cs, col % cs
                dst_idx = batch_idx + [col_blk, row_blk, row_off, col_off]
                src_idx = batch_idx + [row, col]
                builder.store(output, dst_idx, builder.load(input_, src_idx))
    return builder.get()

if (shape[-1] % cs == 0 and shape[-2] % cs == 0):
return tvm.extern(output_shape, [data], lambda ins, outs: kernel_ir(ins[0], outs[0]), name=output_name,
dtype=data.dtype)
else:
return tvm.compute(output_shape, fcompute, name=output_name)

# FRACTAL_NZ: zN fractal format
if (src_format == "DefaultFormat" or src_format == "NCHW") and dst_format == "FRACTAL_NZ":


Loading…
Cancel
Save