Browse Source

batchmatmul tunning adaption

pull/87/head
wYann 5 years ago
parent
commit
83a3edfa14
2 changed files with 11 additions and 3 deletions
  1. +1
    -1
      python/akg/auto_tune/job.py
  2. +10
    -2
      tests/common/gen_json_data.py

+ 1
- 1
python/akg/auto_tune/job.py View File

@@ -75,7 +75,7 @@ def gen_bool_list(attr_list):

def get_matmul_op_desc(json_input):
for op_desc in json_input["op_desc"]:
if op_desc["name"] == "MatMul":
if op_desc["name"] == "MatMul" or op_desc["name"] == "BatchMatMul":
return op_desc
return None



+ 10
- 2
tests/common/gen_json_data.py View File

@@ -262,7 +262,11 @@ def convert_fracal_shape(ori_shape, fractal):
def matmul_str(inputs, output, attr):

left_format = get_attr(attr, "left_format")
if left_format == None:
left_format = get_attr(attr, "pri_format")
right_format = get_attr(attr, "right_format")
if right_format == None:
right_format = get_attr(attr, "pri_format")
trans_a = get_attr(attr, "transpose_a")
trans_b = get_attr(attr, "transpose_b")
left_input = inputs[0][0]
@@ -283,7 +287,11 @@ def matmul_str(inputs, output, attr):
right_ori_shape = convert_fracal_shape(right_input['shape'], "zN")
right_trans_str = get_trans_data_str(right_input_name, right_input_name, right_ori_shape, right_format, 'DefaultFormat')
res = res + right_trans_str + "\n"
matmul_str = np_matmul_str(inputs, output, attr)
has_batch = (len(left_input['shape']) > 4)
if has_batch:
matmul_str = batchmatmul_str(inputs, output, attr)
else:
matmul_str = np_matmul_str(inputs, output, attr)
res = res + matmul_str + "\n"

has_bias = (len(inputs) > 2)
@@ -388,7 +396,7 @@ op_dsl = {
"Transpose": lambda inputs, output, attr: transpose_str(inputs, output, attr),
"TransData": trans_data_dsl,
"BroadcastTo": lambda inputs, output, attr: broadcast_str(inputs, output, attr),
"BatchMatMul": lambda inputs, output, attr: batchmatmul_str(inputs, output, attr),
"BatchMatMul": lambda inputs, output, attr: matmul_str(inputs, output, attr),
"Assign": lambda inputs, output, attr: "%s = %s; %s = %s" %
(get_input(inputs[0][0]), get_input(inputs[1][0]), output[0]['tensor_name'],
get_input(inputs[1][0])),


Loading…
Cancel
Save