|
|
|
@@ -262,7 +262,11 @@ def convert_fracal_shape(ori_shape, fractal): |
|
|
|
def matmul_str(inputs, output, attr): |
|
|
|
|
|
|
|
left_format = get_attr(attr, "left_format") |
|
|
|
if left_format == None: |
|
|
|
left_format = get_attr(attr, "pri_format") |
|
|
|
right_format = get_attr(attr, "right_format") |
|
|
|
if right_format == None: |
|
|
|
right_format = get_attr(attr, "pri_format") |
|
|
|
trans_a = get_attr(attr, "transpose_a") |
|
|
|
trans_b = get_attr(attr, "transpose_b") |
|
|
|
left_input = inputs[0][0] |
|
|
|
@@ -283,7 +287,11 @@ def matmul_str(inputs, output, attr): |
|
|
|
right_ori_shape = convert_fracal_shape(right_input['shape'], "zN") |
|
|
|
right_trans_str = get_trans_data_str(right_input_name, right_input_name, right_ori_shape, right_format, 'DefaultFormat') |
|
|
|
res = res + right_trans_str + "\n" |
|
|
|
matmul_str = np_matmul_str(inputs, output, attr) |
|
|
|
has_batch = (len(left_input['shape']) > 4) |
|
|
|
if has_batch: |
|
|
|
matmul_str = batchmatmul_str(inputs, output, attr) |
|
|
|
else: |
|
|
|
matmul_str = np_matmul_str(inputs, output, attr) |
|
|
|
res = res + matmul_str + "\n" |
|
|
|
|
|
|
|
has_bias = (len(inputs) > 2) |
|
|
|
@@ -388,7 +396,7 @@ op_dsl = { |
|
|
|
"Transpose": lambda inputs, output, attr: transpose_str(inputs, output, attr), |
|
|
|
"TransData": trans_data_dsl, |
|
|
|
"BroadcastTo": lambda inputs, output, attr: broadcast_str(inputs, output, attr), |
|
|
|
"BatchMatMul": lambda inputs, output, attr: batchmatmul_str(inputs, output, attr), |
|
|
|
"BatchMatMul": lambda inputs, output, attr: matmul_str(inputs, output, attr), |
|
|
|
"Assign": lambda inputs, output, attr: "%s = %s; %s = %s" % |
|
|
|
(get_input(inputs[0][0]), get_input(inputs[1][0]), output[0]['tensor_name'], |
|
|
|
get_input(inputs[1][0])), |
|
|
|
|