| @@ -33,5 +33,5 @@ config = ed({ | |||
| "save_checkpoint_path": "./", | |||
| "label_smooth": 1, | |||
| "label_smooth_factor": 0.1, | |||
| "frequency": 278 | |||
| "frequency": 834 | |||
| }) | |||
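For context, `frequency` is the step interval at which THOR refreshes its second-order (Fisher) statistics; between refreshes the cached factors are reused, and this change triples the interval from 278 to 834 steps. A minimal, framework-agnostic sketch of that control flow; every name below is an illustrative placeholder, not an API from this change:

```python
# Illustrative sketch only: `trainer.forward_backward`, `update_curvature`, and
# `apply_update` are hypothetical placeholders, not MindSpore APIs.
def train(trainer, data_iter, frequency, num_steps):
    for step in range(num_steps):
        loss, grads = trainer.forward_backward(next(data_iter))
        if step % frequency == 0:
            trainer.update_curvature()   # recompute the A/G second-order factors
        trainer.apply_update(grads)      # other steps reuse the cached factors
    return loss
```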
| @@ -22,12 +22,6 @@ from mindspore.nn.optim.optimizer import Optimizer | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.parallel._utils import _get_device_num, _get_mirror_mean | |||
| from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight | |||
| from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast | |||
| from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft | |||
| from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul | |||
| from model.grad_reducer_thor import DistributedGradReducerThor | |||
| momentum_opt = C.MultitypeFuncGraph("momentum_opt") | |||
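`momentum_opt` above is a `MultitypeFuncGraph`, a dispatch container: per-parameter update functions are registered against type signatures and then mapped over the weight/gradient tuples (typically via `C.HyperMap`). A hedged sketch of the usual registration pattern follows; the argument list is illustrative and may not match THOR's actual registered function:

```python
# Sketch of the common MultitypeFuncGraph registration pattern (illustrative).
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt(opt, momentum, learning_rate, gradient, weight, moment):
    success = True
    # F.depend keeps the side-effecting ApplyMomentum-style call in the graph.
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success
```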
| @@ -68,10 +62,10 @@ class THOR(Optimizer): | |||
| self.matrix_G = ParameterTuple(matrix_G) | |||
| self.A_inv_max = ParameterTuple(A_inv_max) | |||
| self.G_inv_max = ParameterTuple(G_inv_max) | |||
| self.cube_matmul_left = CusMatMulCubeFraczLeftCast() | |||
| self.cube_matmul_left_fc = CusMatMulCubeDenseLeft() | |||
| self.cube_matmul_right_fc = CusMatMulCubeDenseRight() | |||
| self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul() | |||
| self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast() | |||
| self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft() | |||
| self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight() | |||
| self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul() | |||
| self.transpose = P.Transpose() | |||
| self.shape = P.Shape() | |||
| self.reshape = P.Reshape() | |||
| @@ -23,19 +23,9 @@ from mindspore.common.tensor import Tensor | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.nn.layer.activation import get_activation | |||
| from mindspore.ops import operations as P | |||
| from cus_ops.cus_batch_matmul import CusBatchMatMul | |||
| from cus_ops.cus_cholesky_trsm import CusCholeskyTrsm | |||
| from cus_ops.cus_fused_abs_max1 import CusFusedAbsMax1 | |||
| from cus_ops.cus_img2col import CusImg2Col | |||
| from cus_ops.cus_matmul_cube import CusMatMulCube | |||
| from cus_ops.cus_matrix_combine import CusMatrixCombine | |||
| from cus_ops.cus_transpose02314 import CusTranspose02314 | |||
| import numpy as np | |||
| C0 = 16 | |||
| def caculate_device_shape(matrix_dim, channel, is_A): | |||
| ll = (0) | |||
| if is_A: | |||
| @@ -153,11 +143,11 @@ class Conv2d_Thor(_Conv): | |||
| group=self.group | |||
| ) | |||
| self.img2col = CusImg2Col(ksizes=ksizes, strides=strides) | |||
| self.cube_matmul = CusMatMulCube(transpose_a=True) | |||
| self.matrix_combine = CusMatrixCombine() | |||
| self.cholesky = CusCholeskyTrsm() | |||
| self.transpose02314 = CusTranspose02314() | |||
| self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides) | |||
| self.cube_matmul = P.CusMatMulCube(transpose_a=True) | |||
| self.matrix_combine = P.CusMatrixCombine() | |||
| self.cholesky = P.CusCholeskyTrsm() | |||
| self.transpose02314 = P.CusTranspose02314() | |||
| self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1] | |||
| self.matrix_G_dim = self.out_channels | |||
| self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim, | |||
| @@ -190,7 +180,7 @@ class Conv2d_Thor(_Conv): | |||
| self.mul = P.Mul() | |||
| self.cast = P.Cast() | |||
| self.damping = Tensor(damping) | |||
| self.vector_matmul = CusBatchMatMul() | |||
| self.vector_matmul = P.CusBatchMatMul() | |||
| self.diag_block_dim = 128 | |||
| self.channels_slice_flag = False | |||
| if self.in_channels % C0 != 0: | |||
| @@ -221,8 +211,8 @@ class Conv2d_Thor(_Conv): | |||
| self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32) | |||
| self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32) | |||
| self.fused_abs_max1 = CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) | |||
| self.fused_abs_max2 = CusFusedAbsMax1() | |||
| self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) | |||
| self.fused_abs_max2 = P.CusFusedAbsMax1() | |||
| self.log = P.Log() | |||
| self.exp = P.Exp() | |||
| self.sqrt = P.Sqrt() | |||
| @@ -375,9 +365,9 @@ class Dense_Thor(Cell): | |||
| self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)) | |||
| self.matmul = P.MatMul(transpose_b=True) | |||
| self.cube_matmul = CusMatMulCube(transpose_a=True) | |||
| self.matrix_combine = CusMatrixCombine() | |||
| self.cholesky = CusCholeskyTrsm() | |||
| self.cube_matmul = P.CusMatMulCube(transpose_a=True) | |||
| self.matrix_combine = P.CusMatrixCombine() | |||
| self.cholesky = P.CusCholeskyTrsm() | |||
| self.shape = P.Shape() | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| @@ -386,7 +376,7 @@ class Dense_Thor(Cell): | |||
| self.cast = P.Cast() | |||
| self.damping = Tensor(damping) | |||
| self.loss_scale = Tensor(1 / loss_scale, mstype.float16) | |||
| self.vector_matmul = CusBatchMatMul() | |||
| self.vector_matmul = P.CusBatchMatMul() | |||
| self.pad = P.Pad(((0, 24), (0, 24))) | |||
| self.pad1 = P.Pad(((0, 8), (0, 8))) | |||
| self.slice = P.Slice() | |||
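(As a note on the padding above: `P.Pad(((0, 24), (0, 24)))` appends 24 rows and 24 columns, so a 1000 x 1000 matrix becomes 1024 x 1024; presumably this aligns the 1000-class dimension to the 128-wide diagonal blocks used by the THOR layers, though the diff itself does not state the rationale.)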
| @@ -396,8 +386,8 @@ class Dense_Thor(Cell): | |||
| self.axis = 0 | |||
| self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) | |||
| self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) | |||
| self.fused_abs_max1 = CusFusedAbsMax1([1000, 1000]) | |||
| self.fused_abs_max2 = CusFusedAbsMax1() | |||
| self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000]) | |||
| self.fused_abs_max2 = P.CusFusedAbsMax1() | |||
| self.log = P.Log() | |||
| self.exp = P.Exp() | |||
| self.dampingA = Tensor(np.identity(2048), mstype.float32) | |||
| @@ -33,6 +33,7 @@ cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \ | |||
| def _get_flattern_shape(shape): | |||
| """_get_flattern_shape""" | |||
| flattern_shape = 1 | |||
| for dim in shape: | |||
| flattern_shape *= dim | |||
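(This helper simply multiplies the dimensions together; for example, a fractal shape of (128, 63, 16, 16) flattens to 128 * 63 * 16 * 16 = 2,064,384 elements.)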
| @@ -40,6 +41,7 @@ def _get_flattern_shape(shape): | |||
| def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index): | |||
| """_inner_matmul_new""" | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", scope=tik.scope_ubuf) | |||
| t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 16, 0, 0) | |||
| @@ -71,6 +73,7 @@ def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_ | |||
| def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index): | |||
| """_inner_matmul_new_1_64_32_64""" | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [64], name="input_1_local_UB", scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 8, 0, 0) | |||
| with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: | |||
| @@ -90,6 +93,7 @@ def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, inpu | |||
| @op_info_register(cus_batchmatmul_op_info) | |||
| def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): | |||
| """CusBatchMatMul""" | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| @@ -116,7 +120,6 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr | |||
| # if not transpose_a and transpose_b: | |||
| batch, m, k = x1_shape | |||
| _, n, _ = x2_shape | |||
| input1_shape = _get_flattern_shape(x1_shape) | |||
| input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm) | |||
| @@ -32,6 +32,7 @@ cus_cholesky_trsm_op_info = TBERegOp("CusCholeskyTrsm") \ | |||
| @op_info_register(cus_cholesky_trsm_op_info) | |||
| def CusCholeskyTrsm(input_x, output, kernel_name): | |||
| """CusCholeskyTrsm""" | |||
| input_x_shape = input_x.get("shape") | |||
| output_shape = output.get("shape") | |||
| split_dim = 128 | |||
| @@ -33,6 +33,7 @@ cus_fused_abs_max1_op_info = TBERegOp("CusFusedAbsMax1") \ | |||
| @op_info_register(cus_fused_abs_max1_op_info) | |||
| def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"): | |||
| """CusFusedAbsMax1""" | |||
| input_x_shape = input_x.get("shape") | |||
| output_shape = output.get("shape") | |||
| @@ -203,9 +204,9 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m | |||
| tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) | |||
| elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( | |||
| input_x_shape[0] == 32 and input_x_shape[1] == 16) or ( | |||
| input_x_shape[0] == 16 and input_x_shape[1] == 32): | |||
| input_x_shape[0] == 16 and input_x_shape[1] == 32): | |||
| if (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[ | |||
| 0] == 1000: | |||
| 0] == 1000: | |||
| input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) | |||
| res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) | |||
| blocks = 32 | |||
| @@ -257,7 +258,7 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m | |||
| tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) | |||
| elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[ | |||
| 0] == 1001: | |||
| 0] == 1001: | |||
| input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) | |||
| res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) | |||
| blocks = 32 | |||
| @@ -350,7 +351,7 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m | |||
| tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) | |||
| elif (input_x_shape[0] == 16 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( | |||
| input_x_shape[0] == 16 and input_x_shape[1] == 64) or ( | |||
| input_x_shape[0] == 64 and input_x_shape[1] == 16): | |||
| input_x_shape[0] == 64 and input_x_shape[1] == 16): | |||
| input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) | |||
| res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) | |||
| total_elements = 1 | |||
| @@ -36,6 +36,7 @@ cus_img2col_info = TBERegOp("CusImg2Col") \ | |||
| @op_info_register(cus_img2col_info) | |||
| def CusImg2Col(input_x, output, ksizes, strides, dilates, mode, kernel_name="img2col"): | |||
| """CusImg2Col""" | |||
| input_x_shape = input_x.get("shape") | |||
| input_x_dtype = input_x.get("dtype") | |||
| N, C1, H, W, C0 = input_x_shape | |||
| @@ -64,7 +65,7 @@ def CusImg2Col(input_x, output, ksizes, strides, dilates, mode, kernel_name="img | |||
| ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1)), | |||
| ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1)), | |||
| ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1)), | |||
| ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)), ] | |||
| ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)),] | |||
| if input_shape not in supported_shape: | |||
| raise RuntimeError("input_shape %s is not supported" % str(input_shape)) | |||
| @@ -17,10 +17,9 @@ limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| import te.lang.cce | |||
| import te.platform.cce_params as cce | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| from te import tik | |||
| from te import tvm | |||
| from topi import generic | |||
| @@ -128,7 +127,7 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): | |||
| if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: | |||
| raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) | |||
| if len(shape_bias): | |||
| if len(shape_bias) != 0: | |||
| if len(shape_bias) == 1: | |||
| if is_gevm or is_gemv: | |||
| if shape_bias[0] != m_shape * n_shape: | |||
| @@ -145,16 +144,19 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): | |||
| def _get_bias(shape_bias): | |||
| bias_length = shape_bias[0] | |||
| shb = [] | |||
| if bias_length % 16 == 0: | |||
| return shape_bias | |||
| shb = shape_bias | |||
| else: | |||
| bias_length = (bias_length // 16) * 16 + 16 | |||
| shape_bias = [] | |||
| shape_bias.append(bias_length) | |||
| return shape_bias | |||
| shb = shape_bias | |||
| return shb | |||
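The rewritten `_get_bias` still rounds a bias length that is not a multiple of 16 up to the next multiple of 16 (e.g. 1000 becomes 1008); `(bias_length // 16) * 16 + 16` is equivalent, for such lengths, to the usual round-up idiom, as this small sanity sketch (not part of the diff) checks:

```python
# For n not already a multiple of 16, floor(n/16)*16 + 16 == ceil(n/16)*16.
assert all(
    (n // 16) * 16 + 16 == ((n + 15) // 16) * 16
    for n in range(1, 2049)
    if n % 16 != 0
)
```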
| def _get_input_shape(shape_x): | |||
| """_get_input_shape""" | |||
| dim_a = shape_x[0] | |||
| dim_b = shape_x[1] | |||
| res = [] | |||
| @@ -173,6 +175,7 @@ def _get_input_shape(shape_x): | |||
| def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): | |||
| """check_supported""" | |||
| shape_a = input_x1.get("shape") | |||
| shape_b = input_x2.get("shape") | |||
| print("shape_a: ", shape_a) | |||
| @@ -183,8 +186,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| if bias is not None and bool(bias): | |||
| shape_bias = bias.get("shape") | |||
| try: | |||
| trans_a_f = bool(1 - trans_a) | |||
| if src_dtype == "float32" or src_dtype == "int32": | |||
| @@ -250,7 +251,7 @@ def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=F | |||
| """ | |||
| calculating matrix multiplication with bias, C = A*B + bias, support input | |||
| data with fractal format. | |||
| Parameters: | |||
| shape_a: list or tuple | |||
| Shape of the first tensor a with rank > 1 | |||
| @@ -269,7 +270,7 @@ def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=F | |||
| If True, the input data format of a and b must be fractal format | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| Returns | |||
| ------- | |||
| None | |||
| @@ -2,19 +2,19 @@ | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| Copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| @@ -43,11 +43,12 @@ matmul_cube_dense_right_op_info = TBERegOp("CusMatMulCubeDenseRight") \ | |||
| @op_info_register(matmul_cube_dense_right_op_info) | |||
| def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="matmulcube"): | |||
| """CusMatMulCubeDenseRight""" | |||
| shape_a_temp = (128, 63, 16, 16) | |||
| shape_b_temp = (128, 128, 16, 16) | |||
| shape_output = output_y.get("shape") | |||
| matrix_max_shape = (1,) | |||
| support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape), ] | |||
| support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape),] | |||
| shape_a_input = input_x1.get("shape") | |||
| shape_b_input = input_x2.get("shape") | |||
| matrix_max_input = input_x3.get("shape") | |||
| @@ -62,7 +63,7 @@ def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={} | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm) | |||
| input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm) | |||
| input_x3 = tik_instance.Tensor("float32", [1, ], name="matrix_max", scope=tik.scope_gm) | |||
| input_x3 = tik_instance.Tensor("float32", [1,], name="matrix_max", scope=tik.scope_gm) | |||
| resMatmul = tik_instance.Tensor("float32", shape_output, name="output", scope=tik.scope_gm) | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_index: | |||
| core_m_idx = block_index // 16 | |||
| @@ -17,9 +17,8 @@ limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| import te.platform.cce_params as cce | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| import te.platform.cce_params as cce | |||
| from te import tik | |||
| from topi.cce import util | |||
| @@ -141,6 +140,7 @@ src_dtype: str | |||
| def _get_bias(shape_bias): | |||
| """_get_bias""" | |||
| bias_length = shape_bias[0] | |||
| if bias_length % 16 == 0: | |||
| return shape_bias | |||
| @@ -152,6 +152,7 @@ def _get_bias(shape_bias): | |||
| def _get_input_shape(shape_x): | |||
| """_get_input_shape""" | |||
| dim_a = shape_x[0] | |||
| dim_b = shape_x[1] | |||
| res = [] | |||
| @@ -170,6 +171,7 @@ def _get_input_shape(shape_x): | |||
| def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): | |||
| """check_supported""" | |||
| shape_a = input_x1.get("shape") | |||
| shape_b = input_x2.get("shape") | |||
| print("shape_a: ", shape_a) | |||
| @@ -180,8 +182,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| if bias is not None and bool(bias): | |||
| shape_bias = bias.get("shape") | |||
| try: | |||
| trans_a_f = bool(1 - trans_a) | |||
| if src_dtype == "float32" or src_dtype == "int32": | |||
| @@ -265,7 +265,7 @@ def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans | |||
| If True, the input data format of a and b must be fractal format | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| Returns | |||
| ------- | |||
| None | |||
| @@ -381,6 +381,7 @@ def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans | |||
| def get_cus_tile_info(input_x1, input_x2, diag_size): | |||
| """get_cus_tile_info""" | |||
| tile_map = { | |||
| ((32, 32, 16, 16), (128, 32, 16, 16)): (8, 8, 16), | |||
| ((8, 8, 16, 16), (72, 8, 16, 16)): (8, 8, 4), | |||
| @@ -415,8 +416,9 @@ def get_cus_tile_info(input_x1, input_x2, diag_size): | |||
| def cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, | |||
| res, mo_tile, ko_tile, no_tile, diag_opt=False, diag_size=128): | |||
| ko, mo, mi, ki = input_x1.shape | |||
| no, ko, ki, ni = input_x2.shape | |||
| """cus_cube_matmul_cast""" | |||
| ko, mo, _, _ = input_x1.shape | |||
| no, ko, ki, _ = input_x2.shape | |||
| c0 = input_x1.shape[-1] | |||
| diag_outer = diag_size // c0 | |||
| maxblocknum = 32 | |||
| @@ -47,6 +47,7 @@ cus_matmul_cube_fracz_right_mul_op_info = TBERegOp("CusMatMulCubeFraczRightMul") | |||
| @op_info_register(cus_matmul_cube_fracz_right_mul_op_info) | |||
| def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="matmulcube"): | |||
| """CusMatMulCubeFraczRightMul""" | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| @@ -80,7 +81,7 @@ def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y | |||
| ((64, 32, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), | |||
| ((16, 64, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32')] | |||
| input_shape = ( | |||
| tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype) | |||
| tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype) | |||
| if input_shape not in Supported: | |||
| raise RuntimeError("input_shape %s is not supported" % str(input_shape)) | |||
| @@ -95,15 +96,17 @@ def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y | |||
| def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, | |||
| res): | |||
| """cus_cube_matmul_right_mul""" | |||
| diag_size = 128 | |||
| ko, mo, mi, ki = input_x1.shape | |||
| no, ko, ki, ni = input_x2.shape | |||
| ko, mo, _, _ = input_x1.shape | |||
| no, ko, ki, _ = input_x2.shape | |||
| c0 = input_x1.shape[-1] | |||
| diag_outer = diag_size // c0 | |||
| if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]: | |||
| raise ValueError("shape of input_x1 or input_x2 is not supported!") | |||
| def get_cus_tile_info(input_x1, input_x2, input_x3): | |||
| """get_cus_tile_info""" | |||
| input_shape = (tuple(input_x1.shape), input_x1.dtype, tuple(input_x2.shape), input_x2.dtype, | |||
| tuple(input_x3.shape), input_x3.dtype) | |||
| tile_map = { | |||
| @@ -18,11 +18,10 @@ limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| import te.lang.cce | |||
| import te.platform.cce_params as cce | |||
| from impl.matmul_vector import matmul_vector_cce | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| import te.lang.cce | |||
| import te.platform.cce_params as cce | |||
| from te import tvm | |||
| from topi import generic | |||
| from topi.cce import util | |||
| @@ -146,6 +145,7 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): | |||
| def _get_bias(shape_bias): | |||
| """_get_bias""" | |||
| bias_length = shape_bias[0] | |||
| if bias_length % 16 == 0: | |||
| return shape_bias | |||
| @@ -157,6 +157,7 @@ def _get_bias(shape_bias): | |||
| def _get_input_shape(shape_x): | |||
| """_get_input_shape""" | |||
| dim_a = shape_x[0] | |||
| dim_b = shape_x[1] | |||
| res = [] | |||
| @@ -175,6 +176,7 @@ def _get_input_shape(shape_x): | |||
| def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): | |||
| """check_supported""" | |||
| shape_a = input_x1.get("shape") | |||
| shape_b = input_x2.get("shape") | |||
| print("shape_a: ", shape_a) | |||
| @@ -185,8 +187,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| if bias is not None and bool(bias): | |||
| shape_bias = bias.get("shape") | |||
| try: | |||
| trans_a_f = bool(1 - trans_a) | |||
| if src_dtype == "float32" or src_dtype == "int32": | |||
| @@ -250,7 +250,7 @@ def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, tra | |||
| """ | |||
| calculating matrix multiplication with bias, C = A*B + bias, support input | |||
| data with fractal format. | |||
| Parameters: | |||
| shape_a: list or tuple | |||
| Shape of the first tensor a with rank > 1 | |||
| @@ -269,7 +269,7 @@ def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, tra | |||
| If True, the input data format of a and b must be fractal format | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| Returns | |||
| ------- | |||
| None | |||
| @@ -32,6 +32,7 @@ cus_transpose02314_op_info = TBERegOp("CusTranspose02314") \ | |||
| @op_info_register(cus_transpose02314_op_info) | |||
| def CusTranspose02314(input_x, output, kernel_name="transpose021354"): | |||
| """CusTranspose02314""" | |||
| input_x_shape = input_x.get("shape") | |||
| output_shape = output.get("shape") | |||
| perm = (0, 2, 3, 1, 4) | |||
| @@ -263,7 +264,7 @@ def CusTranspose02314(input_x, output, kernel_name="transpose021354"): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| with tik_instance.for_range(0, 6, thread_num=2) as cc1: | |||
| _inner_ + compute(cc1) | |||
| _inner_compute(cc1) | |||
| _inner_compute(6) | |||
| elif tuple(input_x_shape) == (32, 64, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": | |||
| def _inner_compute(split_index, block_idx): | |||
| @@ -91,7 +91,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer): | |||
| def infer_shape(self, data1_shape): | |||
| ll = [] | |||
| if len(data1_shape) == 2: | |||
| ll = [1, ] | |||
| ll = [1,] | |||
| else: | |||
| ll = [32, 64] | |||
| return ll | |||