@@ -33,5 +33,5 @@ config = ed({
     "save_checkpoint_path": "./",
     "label_smooth": 1,
     "label_smooth_factor": 0.1,
-    "frequency": 278
+    "frequency": 834
 })
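Editor's note: in THOR, a `frequency`-style setting is commonly the interval, in training steps, at which the optimizer refreshes its second-order statistics; the exact semantics depend on the training script, so the sketch below is an illustration only and not part of this patch (the step count is hypothetical).

```python
# Illustration only (assumption, not from this patch): a "frequency"-style
# interval usually means the second-order information is re-estimated every
# `frequency` steps and reused in between.
frequency = 834
total_steps = 5004  # hypothetical step count, for illustration

refresh_steps = [step for step in range(total_steps) if step % frequency == 0]
print(refresh_steps)  # steps at which the THOR statistics would be refreshed
```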
@@ -22,12 +22,6 @@ from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
-from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight
-from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast
-from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft
-from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul
-from model.grad_reducer_thor import DistributedGradReducerThor
 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -68,10 +62,10 @@ class THOR(Optimizer):
         self.matrix_G = ParameterTuple(matrix_G)
         self.A_inv_max = ParameterTuple(A_inv_max)
         self.G_inv_max = ParameterTuple(G_inv_max)
-        self.cube_matmul_left = CusMatMulCubeFraczLeftCast()
-        self.cube_matmul_left_fc = CusMatMulCubeDenseLeft()
-        self.cube_matmul_right_fc = CusMatMulCubeDenseRight()
-        self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul()
+        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
+        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
+        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
+        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
         self.transpose = P.Transpose()
         self.shape = P.Shape()
         self.reshape = P.Reshape()
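Editor's note: the hunks above and below follow a single pattern: the custom cube operators are no longer imported from standalone `cus_ops` modules but are constructed through the `mindspore.ops.operations` namespace (imported as `P`). A minimal usage sketch, assuming these primitives are registered under that namespace in this code base:

```python
# Illustrative only -- assumes the Cus* primitives are exposed through
# mindspore.ops.operations (imported as P) after this change.
from mindspore.ops import operations as P

cube_matmul_left = P.CusMatMulCubeFraczLeftCast()        # was cus_ops.cus_matmul_cube_fracz_left_cast
cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()   # was cus_ops.cus_matmul_cube_fracz_right_mul
```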
@@ -23,19 +23,9 @@ from mindspore.common.tensor import Tensor
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 from mindspore.ops import operations as P
-from cus_ops.cus_batch_matmul import CusBatchMatMul
-from cus_ops.cus_cholesky_trsm import CusCholeskyTrsm
-from cus_ops.cus_fused_abs_max1 import CusFusedAbsMax1
-from cus_ops.cus_img2col import CusImg2Col
-from cus_ops.cus_matmul_cube import CusMatMulCube
-from cus_ops.cus_matrix_combine import CusMatrixCombine
-from cus_ops.cus_transpose02314 import CusTranspose02314
 import numpy as np
 C0 = 16
 def caculate_device_shape(matrix_dim, channel, is_A):
     ll = (0)
     if is_A:
@@ -153,11 +143,11 @@ class Conv2d_Thor(_Conv):
                                group=self.group
                                )
-        self.img2col = CusImg2Col(ksizes=ksizes, strides=strides)
-        self.cube_matmul = CusMatMulCube(transpose_a=True)
-        self.matrix_combine = CusMatrixCombine()
-        self.cholesky = CusCholeskyTrsm()
-        self.transpose02314 = CusTranspose02314()
+        self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
+        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
+        self.matrix_combine = P.CusMatrixCombine()
+        self.cholesky = P.CusCholeskyTrsm()
+        self.transpose02314 = P.CusTranspose02314()
         self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
         self.matrix_G_dim = self.out_channels
         self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim,
@@ -190,7 +180,7 @@ class Conv2d_Thor(_Conv):
         self.mul = P.Mul()
         self.cast = P.Cast()
         self.damping = Tensor(damping)
-        self.vector_matmul = CusBatchMatMul()
+        self.vector_matmul = P.CusBatchMatMul()
         self.diag_block_dim = 128
         self.channels_slice_flag = False
         if self.in_channels % C0 != 0:
@@ -221,8 +211,8 @@ class Conv2d_Thor(_Conv):
         self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32)
         self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32)
-        self.fused_abs_max1 = CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim])
-        self.fused_abs_max2 = CusFusedAbsMax1()
+        self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim])
+        self.fused_abs_max2 = P.CusFusedAbsMax1()
         self.log = P.Log()
         self.exp = P.Exp()
         self.sqrt = P.Sqrt()
@@ -375,9 +365,9 @@ class Dense_Thor(Cell):
         self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16))
         self.matmul = P.MatMul(transpose_b=True)
-        self.cube_matmul = CusMatMulCube(transpose_a=True)
-        self.matrix_combine = CusMatrixCombine()
-        self.cholesky = CusCholeskyTrsm()
+        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
+        self.matrix_combine = P.CusMatrixCombine()
+        self.cholesky = P.CusCholeskyTrsm()
         self.shape = P.Shape()
         self.reshape = P.Reshape()
         self.transpose = P.Transpose()
@@ -386,7 +376,7 @@ class Dense_Thor(Cell):
         self.cast = P.Cast()
         self.damping = Tensor(damping)
         self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
-        self.vector_matmul = CusBatchMatMul()
+        self.vector_matmul = P.CusBatchMatMul()
         self.pad = P.Pad(((0, 24), (0, 24)))
         self.pad1 = P.Pad(((0, 8), (0, 8)))
         self.slice = P.Slice()
@@ -396,8 +386,8 @@ class Dense_Thor(Cell):
         self.axis = 0
         self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
         self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
-        self.fused_abs_max1 = CusFusedAbsMax1([1000, 1000])
-        self.fused_abs_max2 = CusFusedAbsMax1()
+        self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000])
+        self.fused_abs_max2 = P.CusFusedAbsMax1()
         self.log = P.Log()
         self.exp = P.Exp()
         self.dampingA = Tensor(np.identity(2048), mstype.float32)
@@ -33,6 +33,7 @@ cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \
 def _get_flattern_shape(shape):
+    """_get_flattern_shape"""
     flattern_shape = 1
     for dim in shape:
         flattern_shape *= dim
@@ -40,6 +41,7 @@ def _get_flattern_shape(shape):
 def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index):
+    """_inner_matmul_new"""
     input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", scope=tik.scope_ubuf)
     t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", scope=tik.scope_ubuf)
     tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 16, 0, 0)
@@ -71,6 +73,7 @@ def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_
 def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index):
+    """_inner_matmul_new_1_64_32_64"""
     input_1_local_UB = tik_instance.Tensor(dtype, [64], name="input_1_local_UB", scope=tik.scope_ubuf)
     tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 8, 0, 0)
     with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2:
@@ -90,6 +93,7 @@ def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, inpu
 @op_info_register(cus_batchmatmul_op_info)
 def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
+    """CusBatchMatMul"""
     if util.get_product_version() == util.VERSION_MINI:
         tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
     else:
@@ -116,7 +120,6 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
     # if not transpose_a and transpose_b:
     batch, m, k = x1_shape
-    _, n, _ = x2_shape
     input1_shape = _get_flattern_shape(x1_shape)
     input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
@@ -32,6 +32,7 @@ cus_cholesky_trsm_op_info = TBERegOp("CusCholeskyTrsm") \
 @op_info_register(cus_cholesky_trsm_op_info)
 def CusCholeskyTrsm(input_x, output, kernel_name):
+    """CusCholeskyTrsm"""
     input_x_shape = input_x.get("shape")
     output_shape = output.get("shape")
     split_dim = 128
@@ -33,6 +33,7 @@ cus_fused_abs_max1_op_info = TBERegOp("CusFusedAbsMax1") \
 @op_info_register(cus_fused_abs_max1_op_info)
 def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"):
+    """CusFusedAbsMax1"""
     input_x_shape = input_x.get("shape")
     output_shape = output.get("shape")
@@ -203,9 +204,9 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
             tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
     elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or (
             input_x_shape[0] == 32 and input_x_shape[1] == 16) or (
-            input_x_shape[0] == 16 and input_x_shape[1] == 32):
+            input_x_shape[0] == 16 and input_x_shape[1] == 32):
         if (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[
-                0] == 1000:
+                0] == 1000:
             input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
             res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
             blocks = 32
@@ -257,7 +258,7 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
             tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
         elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[
-                0] == 1001:
+                0] == 1001:
             input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
             res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
             blocks = 32
@@ -350,7 +351,7 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
             tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
     elif (input_x_shape[0] == 16 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or (
             input_x_shape[0] == 16 and input_x_shape[1] == 64) or (
-            input_x_shape[0] == 64 and input_x_shape[1] == 16):
+            input_x_shape[0] == 64 and input_x_shape[1] == 16):
         input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
         res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
         total_elements = 1
@@ -36,6 +36,7 @@ cus_img2col_info = TBERegOp("CusImg2Col") \
 @op_info_register(cus_img2col_info)
 def CusImg2Col(input_x, output, ksizes, strides, dilates, mode, kernel_name="img2col"):
+    """CusImg2Col"""
     input_x_shape = input_x.get("shape")
     input_x_dtype = input_x.get("dtype")
     N, C1, H, W, C0 = input_x_shape
@@ -64,7 +65,7 @@ def CusImg2Col(input_x, output, ksizes, strides, dilates, mode, kernel_name="img
                        ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1)),
                        ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1)),
                        ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1)),
-                       ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)), ]
+                       ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)),]
     if input_shape not in supported_shape:
         raise RuntimeError("input_shape %s is not supported" % str(input_shape))
@@ -17,10 +17,9 @@ limitations under the License.
 matmul
 """
 from __future__ import absolute_import
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 import te.lang.cce
 import te.platform.cce_params as cce
-from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 from te import tik
 from te import tvm
 from topi import generic
@@ -128,7 +127,7 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
     if n_shape % cce.BLOCK_IN != 0 and n_shape != 1:
         raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN)
-    if len(shape_bias):
+    if len(shape_bias) != 0:
         if len(shape_bias) == 1:
             if is_gevm or is_gemv:
                 if shape_bias[0] != m_shape * n_shape:
@@ -145,16 +144,19 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
 def _get_bias(shape_bias):
     bias_length = shape_bias[0]
+    shb = []
     if bias_length % 16 == 0:
-        return shape_bias
+        shb = shape_bias
     else:
         bias_length = (bias_length // 16) * 16 + 16
         shape_bias = []
         shape_bias.append(bias_length)
-        return shape_bias
+        shb = shape_bias
+    return shb
 def _get_input_shape(shape_x):
+    """_get_input_shape"""
     dim_a = shape_x[0]
     dim_b = shape_x[1]
     res = []
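Editor's note: the `_get_bias` refactor above keeps the padding rule unchanged (a bias length that is not a multiple of 16 is rounded up to the next multiple of 16) while collapsing the two early returns into a single exit point. A standalone sketch of that rule with illustrative values, not part of the patch:

```python
# Illustrative sketch of the rounding used by _get_bias; not part of the patch.
def pad_bias_length(bias_length):
    """Round bias_length up to the next multiple of 16 unless it already is one."""
    if bias_length % 16 == 0:
        return bias_length
    return (bias_length // 16) * 16 + 16

print(pad_bias_length(1000))  # -> 1008
print(pad_bias_length(1024))  # -> 1024
```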
@@ -173,6 +175,7 @@ def _get_input_shape(shape_x):
 def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
+    """check_supported"""
     shape_a = input_x1.get("shape")
     shape_b = input_x2.get("shape")
     print("shape_a: ", shape_a)
@@ -183,8 +186,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t
     util.check_shape_rule(shape_b)
     util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
     util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
-    if bias is not None and bool(bias):
-        shape_bias = bias.get("shape")
     try:
         trans_a_f = bool(1 - trans_a)
         if src_dtype == "float32" or src_dtype == "int32":
@@ -250,7 +251,7 @@ def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=F
     """
     calculating matrix multiplication with bias, C = A*B + bias, support input
     data with fractal format.
-
+
     Parameters:
     shape_a: list or tuple
         Shape of the first tensor a with rank > 1
@@ -269,7 +270,7 @@ def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=F
         If True, the input data format of a and b must be fractal format
     shape_bias: list or tuple
         Shape of bias, only support the input data format with ND
-
+
     Returns
     -------
     None
@@ -2,19 +2,19 @@
 # -*- coding:utf-8 -*-
 """
 copyright 2020 Huawei Technologies Co., Ltd
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License == distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 matmul
 """
 from __future__ import absolute_import
@@ -43,11 +43,12 @@ matmul_cube_dense_right_op_info = TBERegOp("CusMatMulCubeDenseRight") \
 @op_info_register(matmul_cube_dense_right_op_info)
 def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
                             kernel_name="matmulcube"):
+    """CusMatMulCubeDenseRight"""
     shape_a_temp = (128, 63, 16, 16)
     shape_b_temp = (128, 128, 16, 16)
     shape_output = output_y.get("shape")
     matrix_max_shape = (1,)
-    support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape), ]
+    support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape),]
     shape_a_input = input_x1.get("shape")
     shape_b_input = input_x2.get("shape")
     matrix_max_input = input_x3.get("shape")
@@ -62,7 +63,7 @@ def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}
         tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
     input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm)
     input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm)
-    input_x3 = tik_instance.Tensor("float32", [1, ], name="matrix_max", scope=tik.scope_gm)
+    input_x3 = tik_instance.Tensor("float32", [1,], name="matrix_max", scope=tik.scope_gm)
     resMatmul = tik_instance.Tensor("float32", shape_output, name="output", scope=tik.scope_gm)
     with tik_instance.for_range(0, 32, block_num=32) as block_index:
         core_m_idx = block_index // 16
@@ -17,9 +17,8 @@ limitations under the License.
 matmul
 """
 from __future__ import absolute_import
-import te.platform.cce_params as cce
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+import te.platform.cce_params as cce
 from te import tik
 from topi.cce import util
@@ -141,6 +140,7 @@ src_dtype: str
 def _get_bias(shape_bias):
+    """_get_bias"""
     bias_length = shape_bias[0]
     if bias_length % 16 == 0:
         return shape_bias
@@ -152,6 +152,7 @@ def _get_bias(shape_bias):
 def _get_input_shape(shape_x):
+    """_get_input_shape"""
     dim_a = shape_x[0]
     dim_b = shape_x[1]
     res = []
@@ -170,6 +171,7 @@ def _get_input_shape(shape_x):
 def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
+    """check_supported"""
     shape_a = input_x1.get("shape")
     shape_b = input_x2.get("shape")
     print("shape_a: ", shape_a)
@@ -180,8 +182,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t
     util.check_shape_rule(shape_b)
     util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
     util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
-    if bias is not None and bool(bias):
-        shape_bias = bias.get("shape")
     try:
         trans_a_f = bool(1 - trans_a)
         if src_dtype == "float32" or src_dtype == "int32":
@@ -265,7 +265,7 @@ def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans
         If True, the input data format of a and b must be fractal format
     shape_bias: list or tuple
         Shape of bias, only support the input data format with ND
-
+
     Returns
     -------
     None
@@ -381,6 +381,7 @@ def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans
 def get_cus_tile_info(input_x1, input_x2, diag_size):
+    """get_cus_tile_info"""
     tile_map = {
         ((32, 32, 16, 16), (128, 32, 16, 16)): (8, 8, 16),
         ((8, 8, 16, 16), (72, 8, 16, 16)): (8, 8, 4),
@@ -415,8 +416,9 @@ def get_cus_tile_info(input_x1, input_x2, diag_size):
 def cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b,
                          res, mo_tile, ko_tile, no_tile, diag_opt=False, diag_size=128):
-    ko, mo, mi, ki = input_x1.shape
-    no, ko, ki, ni = input_x2.shape
+    """cus_cube_matmul_cast"""
+    ko, mo, _, _ = input_x1.shape
+    no, ko, ki, _ = input_x2.shape
     c0 = input_x1.shape[-1]
     diag_outer = diag_size // c0
     maxblocknum = 32
@@ -47,6 +47,7 @@ cus_matmul_cube_fracz_right_mul_op_info = TBERegOp("CusMatMulCubeFraczRightMul")
 @op_info_register(cus_matmul_cube_fracz_right_mul_op_info)
 def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
                                kernel_name="matmulcube"):
+    """CusMatMulCubeFraczRightMul"""
     if util.get_product_version() == util.VERSION_MINI:
         tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
     else:
@@ -80,7 +81,7 @@ def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y
                  ((64, 32, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'),
                  ((16, 64, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32')]
     input_shape = (
-        tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype)
+        tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype)
     if input_shape not in Supported:
         raise RuntimeError("input_shape %s is not supported" % str(input_shape))
@@ -95,15 +96,17 @@ def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y
 def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3,
                               res):
+    """cus_cube_matmul_right_mul"""
     diag_size = 128
-    ko, mo, mi, ki = input_x1.shape
-    no, ko, ki, ni = input_x2.shape
+    ko, mo, _, _ = input_x1.shape
+    no, ko, ki, _ = input_x2.shape
     c0 = input_x1.shape[-1]
     diag_outer = diag_size // c0
     if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]:
         raise ValueError("shape of input_x1 or input_x2 is not supported!")
 def get_cus_tile_info(input_x1, input_x2, input_x3):
+    """get_cus_tile_info"""
     input_shape = (tuple(input_x1.shape), input_x1.dtype, tuple(input_x2.shape), input_x2.dtype,
                    tuple(input_x3.shape), input_x3.dtype)
     tile_map = {
@@ -18,11 +18,10 @@ limitations under the License.
 matmul
 """
 from __future__ import absolute_import
-import te.lang.cce
-import te.platform.cce_params as cce
 from impl.matmul_vector import matmul_vector_cce
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+import te.lang.cce
+import te.platform.cce_params as cce
 from te import tvm
 from topi import generic
 from topi.cce import util
@@ -146,6 +145,7 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
 def _get_bias(shape_bias):
+    """_get_bias"""
     bias_length = shape_bias[0]
     if bias_length % 16 == 0:
         return shape_bias
@@ -157,6 +157,7 @@ def _get_bias(shape_bias):
 def _get_input_shape(shape_x):
+    """_get_input_shape"""
     dim_a = shape_x[0]
     dim_b = shape_x[1]
     res = []
@@ -175,6 +176,7 @@ def _get_input_shape(shape_x):
 def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
+    """check_supported"""
     shape_a = input_x1.get("shape")
     shape_b = input_x2.get("shape")
     print("shape_a: ", shape_a)
@@ -185,8 +187,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t
     util.check_shape_rule(shape_b)
     util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
     util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
-    if bias is not None and bool(bias):
-        shape_bias = bias.get("shape")
     try:
         trans_a_f = bool(1 - trans_a)
         if src_dtype == "float32" or src_dtype == "int32":
@@ -250,7 +250,7 @@ def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, tra
     """
     calculating matrix multiplication with bias, C = A*B + bias, support input
     data with fractal format.
-
+
     Parameters:
     shape_a: list or tuple
         Shape of the first tensor a with rank > 1
@@ -269,7 +269,7 @@ def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, tra
         If True, the input data format of a and b must be fractal format
     shape_bias: list or tuple
         Shape of bias, only support the input data format with ND
-
+
     Returns
     -------
     None
@@ -32,6 +32,7 @@ cus_transpose02314_op_info = TBERegOp("CusTranspose02314") \
 @op_info_register(cus_transpose02314_op_info)
 def CusTranspose02314(input_x, output, kernel_name="transpose021354"):
+    """CusTranspose02314"""
     input_x_shape = input_x.get("shape")
     output_shape = output.get("shape")
     perm = (0, 2, 3, 1, 4)
@@ -263,7 +264,7 @@ def CusTranspose02314(input_x, output, kernel_name="transpose021354"):
         with tik_instance.for_range(0, 32, block_num=32) as block_idx:
             with tik_instance.for_range(0, 6, thread_num=2) as cc1:
-                _inner_ + compute(cc1)
+                _inner_compute(cc1)
             _inner_compute(6)
     elif tuple(input_x_shape) == (32, 64, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16":
         def _inner_compute(split_index, block_idx):
@@ -91,7 +91,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
     def infer_shape(self, data1_shape):
         ll = []
         if len(data1_shape) == 2:
-            ll = [1, ]
+            ll = [1,]
         else:
             ll = [32, 64]
         return ll