Merge pull request !1442 from zongha/mastertags/v0.5.0-beta
| @@ -31,15 +31,7 @@ config = ed({ | |||
| "save_checkpoint_steps": 5004, | |||
| "keep_checkpoint_max": 20, | |||
| "save_checkpoint_path": "./", | |||
| "lr_init": 0.01, | |||
| "lr_end": 0.00001, | |||
| "lr_max": 0.1, | |||
| "warmup_epochs": 0, | |||
| "lr_decay_mode": "cosine", | |||
| "label_smooth": 1, | |||
| "label_smooth_factor": 0.1, | |||
| "lr": 0.1, | |||
| "T_max": 90, | |||
| "eta_min": 0, | |||
| "frequency": 278 | |||
| "frequency": 834 | |||
| }) | |||
| @@ -22,12 +22,6 @@ from mindspore.nn.optim.optimizer import Optimizer | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.parallel._utils import _get_device_num, _get_mirror_mean | |||
| from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight | |||
| from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast | |||
| from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft | |||
| from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul | |||
| from model.grad_reducer_thor import DistributedGradReducerThor | |||
| momentum_opt = C.MultitypeFuncGraph("momentum_opt") | |||
| @@ -68,10 +62,10 @@ class THOR(Optimizer): | |||
| self.matrix_G = ParameterTuple(matrix_G) | |||
| self.A_inv_max = ParameterTuple(A_inv_max) | |||
| self.G_inv_max = ParameterTuple(G_inv_max) | |||
| self.cube_matmul_left = CusMatMulCubeFraczLeftCast() | |||
| self.cube_matmul_left_fc = CusMatMulCubeDenseLeft() | |||
| self.cube_matmul_right_fc = CusMatMulCubeDenseRight() | |||
| self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul() | |||
| self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast() | |||
| self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft() | |||
| self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight() | |||
| self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul() | |||
| self.transpose = P.Transpose() | |||
| self.shape = P.Shape() | |||
| self.reshape = P.Reshape() | |||
| @@ -23,19 +23,9 @@ from mindspore.common.tensor import Tensor | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.nn.layer.activation import get_activation | |||
| from mindspore.ops import operations as P | |||
| from cus_ops.cus_batch_matmul import CusBatchMatMul | |||
| from cus_ops.cus_cholesky_trsm import CusCholeskyTrsm | |||
| from cus_ops.cus_fused_abs_max1 import CusFusedAbsMax1 | |||
| from cus_ops.cus_img2col import CusImg2Col | |||
| from cus_ops.cus_matmul_cube import CusMatMulCube | |||
| from cus_ops.cus_matrix_combine import CusMatrixCombine | |||
| from cus_ops.cus_transpose02314 import CusTranspose02314 | |||
| import numpy as np | |||
| C0 = 16 | |||
| def caculate_device_shape(matrix_dim, channel, is_A): | |||
| ll = (0) | |||
| if is_A: | |||
| @@ -153,11 +143,11 @@ class Conv2d_Thor(_Conv): | |||
| group=self.group | |||
| ) | |||
| self.img2col = CusImg2Col(ksizes=ksizes, strides=strides) | |||
| self.cube_matmul = CusMatMulCube(transpose_a=True) | |||
| self.matrix_combine = CusMatrixCombine() | |||
| self.cholesky = CusCholeskyTrsm() | |||
| self.transpose02314 = CusTranspose02314() | |||
| self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides) | |||
| self.cube_matmul = P.CusMatMulCube(transpose_a=True) | |||
| self.matrix_combine = P.CusMatrixCombine() | |||
| self.cholesky = P.CusCholeskyTrsm() | |||
| self.transpose02314 = P.CusTranspose02314() | |||
| self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1] | |||
| self.matrix_G_dim = self.out_channels | |||
| self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim, | |||
| @@ -190,7 +180,7 @@ class Conv2d_Thor(_Conv): | |||
| self.mul = P.Mul() | |||
| self.cast = P.Cast() | |||
| self.damping = Tensor(damping) | |||
| self.vector_matmul = CusBatchMatMul() | |||
| self.vector_matmul = P.CusBatchMatMul() | |||
| self.diag_block_dim = 128 | |||
| self.channels_slice_flag = False | |||
| if self.in_channels % C0 != 0: | |||
| @@ -221,8 +211,8 @@ class Conv2d_Thor(_Conv): | |||
| self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32) | |||
| self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32) | |||
| self.fused_abs_max1 = CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) | |||
| self.fused_abs_max2 = CusFusedAbsMax1() | |||
| self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) | |||
| self.fused_abs_max2 = P.CusFusedAbsMax1() | |||
| self.log = P.Log() | |||
| self.exp = P.Exp() | |||
| self.sqrt = P.Sqrt() | |||
| @@ -375,9 +365,9 @@ class Dense_Thor(Cell): | |||
| self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)) | |||
| self.matmul = P.MatMul(transpose_b=True) | |||
| self.cube_matmul = CusMatMulCube(transpose_a=True) | |||
| self.matrix_combine = CusMatrixCombine() | |||
| self.cholesky = CusCholeskyTrsm() | |||
| self.cube_matmul = P.CusMatMulCube(transpose_a=True) | |||
| self.matrix_combine = P.CusMatrixCombine() | |||
| self.cholesky = P.CusCholeskyTrsm() | |||
| self.shape = P.Shape() | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| @@ -386,7 +376,7 @@ class Dense_Thor(Cell): | |||
| self.cast = P.Cast() | |||
| self.damping = Tensor(damping) | |||
| self.loss_scale = Tensor(1 / loss_scale, mstype.float16) | |||
| self.vector_matmul = CusBatchMatMul() | |||
| self.vector_matmul = P.CusBatchMatMul() | |||
| self.pad = P.Pad(((0, 24), (0, 24))) | |||
| self.pad1 = P.Pad(((0, 8), (0, 8))) | |||
| self.slice = P.Slice() | |||
| @@ -396,8 +386,8 @@ class Dense_Thor(Cell): | |||
| self.axis = 0 | |||
| self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) | |||
| self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) | |||
| self.fused_abs_max1 = CusFusedAbsMax1([1000, 1000]) | |||
| self.fused_abs_max2 = CusFusedAbsMax1() | |||
| self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000]) | |||
| self.fused_abs_max2 = P.CusFusedAbsMax1() | |||
| self.log = P.Log() | |||
| self.exp = P.Exp() | |||
| self.dampingA = Tensor(np.identity(2048), mstype.float32) | |||
| @@ -45,8 +45,7 @@ do | |||
| mkdir ./train_parallel$i | |||
| cp *.py ./train_parallel$i | |||
| cp *.sh ./train_parallel$i | |||
| cp -r second_order ./train_parallel$i/second_order | |||
| cp -r test_ops ./train_parallel$i/test_ops | |||
| cp -r model ./train_parallel$i | |||
| cd ./train_parallel$i || exit | |||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | |||
| @@ -0,0 +1,257 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """batch_matmul_impl""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| from te import tik | |||
| from topi.cce import util | |||
| cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batchmatmul.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusBatchMatMul") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| def _get_flattern_shape(shape): | |||
| """_get_flattern_shape""" | |||
| flattern_shape = 1 | |||
| for dim in shape: | |||
| flattern_shape *= dim | |||
| return (flattern_shape,) | |||
| def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index): | |||
| """_inner_matmul_new""" | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", scope=tik.scope_ubuf) | |||
| t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 16, 0, 0) | |||
| with tik_instance.for_range(0, 2) as vec_i: | |||
| tik_instance.vadds(64, t_1_0_local_UB[vec_i * 64], input_1_local_UB[vec_i * 64], 0, 64, 1, 1, 16, 0) | |||
| with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: | |||
| input_2_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="input_2_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| t_1_local_UB = input_2_local_UB | |||
| bisec_last_axis_local_UB = input_2_local_UB | |||
| matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [64], name="matmul_hybrid_f_t_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| matmul_hybrid_f_t_local_UB_dst_tmp = tik_instance.Tensor(dtype, [64], | |||
| name="matmul_hybrid_f_t_local_UB_dst_tmp", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB, 0, 1, 1, 8) | |||
| tik_instance.data_move(input_2_local_UB, input2[input2_index + thread_idx2 * 8192], 0, 1, 1024, 0, 0) | |||
| tik_instance.vmul(64, t_1_local_UB, t_1_0_local_UB, input_2_local_UB, 128, 1, 1, 1, 8, 8, 8) | |||
| tik_instance.vadd(64, bisec_last_axis_local_UB, t_1_local_UB, t_1_local_UB[64], 64, 1, 1, 1, | |||
| 16, 16, 16) | |||
| tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB_dst_tmp, 0, 1, 1, 8) | |||
| with tik_instance.for_range(0, 64) as cc6: | |||
| tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB_dst_tmp[cc6], bisec_last_axis_local_UB[cc6 * 128], | |||
| 1, 1, 1, 8) | |||
| tik_instance.vadd(64, matmul_hybrid_f_t_local_UB, matmul_hybrid_f_t_local_UB_dst_tmp, | |||
| matmul_hybrid_f_t_local_UB, 1, 1, 1, 1, 8, 8, 8) | |||
| tik_instance.data_move(res[res_index + thread_idx2 * 64], | |||
| matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0) | |||
| def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index): | |||
| """_inner_matmul_new_1_64_32_64""" | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [64], name="input_1_local_UB", scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 8, 0, 0) | |||
| with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: | |||
| input_2_local_UB = tik_instance.Tensor(dtype, [32 * 64], name="input_2_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| t_1_local_UB = input_2_local_UB | |||
| matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [32], name="matmul_hybrid_f_t_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_2_local_UB, input2[input2_index + thread_idx2 * 2048], 0, 1, 256, 0, 0) | |||
| tik_instance.vmul(64, t_1_local_UB, input_1_local_UB, input_2_local_UB, 32, 1, 1, 1, 8, 0, 8) | |||
| with tik_instance.for_range(0, 32) as cc6: | |||
| tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB[cc6], t_1_local_UB[cc6 * 64], | |||
| 1, 1, 1, 8) | |||
| tik_instance.data_move(res[res_index + thread_idx2 * 32], | |||
| matmul_hybrid_f_t_local_UB, 0, 1, 4, 0, 0) | |||
| @op_info_register(cus_batchmatmul_op_info) | |||
| def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): | |||
| """CusBatchMatMul""" | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| x1_shape = input_x1.get("shape") | |||
| dtype = input_x1.get("dtype").lower() | |||
| x2_shape = input_x2.get("shape") | |||
| if dtype != input_x2.get("dtype").lower(): | |||
| raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % ( | |||
| dtype, input_x2.get("dtype").lower())) | |||
| input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b) | |||
| support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True), | |||
| ((36, 128, 128), (36, 128, 128), "float32", False, True), | |||
| ((5, 128, 128), (5, 128, 128), "float32", False, True), | |||
| ((18, 128, 128), (18, 128, 128), "float32", False, True), | |||
| ((16, 128, 128), (16, 128, 128), "float32", False, True), | |||
| ((9, 128, 128), (9, 128, 128), "float32", False, True), | |||
| ((1, 64, 64), (1, 64, 64), "float32", False, True), | |||
| ((1, 128, 128), (1, 128, 128), "float32", False, True), | |||
| ((4, 128, 128), (4, 128, 128), "float32", False, True), | |||
| ((2, 128, 128), (2, 128, 128), "float32", False, True)] | |||
| if input_shape not in support_shape: | |||
| raise RuntimeError("input_shape %s is not supported" % str(input_shape)) | |||
| # if not transpose_a and transpose_b: | |||
| batch, m, k = x1_shape | |||
| input1_shape = _get_flattern_shape(x1_shape) | |||
| input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm) | |||
| input2_shape = _get_flattern_shape(x2_shape) | |||
| input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm) | |||
| output_shape = x1_shape | |||
| res_shape = _get_flattern_shape(output_shape) | |||
| res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm) | |||
| if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True): | |||
| with tik_instance.for_range(0, 18, block_num=18) as block_idx: | |||
| with tik_instance.for_range(0, 2) as cc0: | |||
| with tik_instance.for_range(0, 128, thread_num=2) as cc1: | |||
| input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 | |||
| input2_index = block_idx * 32768 + cc0 * 16384 | |||
| res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 | |||
| _inner_matmul_new(tik_instance, dtype, | |||
| input1, input1_index, | |||
| input2, input2_index, | |||
| res, res_index) | |||
| if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True): | |||
| with tik_instance.for_range(0, 30, block_num=30) as block_idx: | |||
| with tik_instance.for_range(0, 11) as cc1_db: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as thread_idx: | |||
| with tik_instance.if_scope(((((block_idx % 6) * 22) + (cc1_db * 2) + thread_idx) < 128)): | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, input1[ | |||
| (block_idx // 6) * 16384 + (block_idx % 6) * 2816 + cc1_db * 256 + thread_idx * 128], 0, 1, | |||
| 16, 0, 0) | |||
| with tik_instance.for_range(0, 2) as vec_i: | |||
| tik_instance.vadds(64, t_1_0_local_UB[vec_i * 64], input_1_local_UB[vec_i * 64], 0, | |||
| 64, 1, 1, 16, 0) | |||
| with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: | |||
| input_2_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="input_2_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| t_1_local_UB = input_2_local_UB | |||
| bisec_last_axis_local_UB = input_2_local_UB | |||
| matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [64], | |||
| name="matmul_hybrid_f_t_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| matmul_hybrid_f_t_local_UB_dst_tmp = tik_instance.Tensor(dtype, [64], | |||
| name="matmul_hybrid_f_t_local_UB_dst_tmp", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB, 0, 1, 1, 8) | |||
| tik_instance.data_move(input_2_local_UB, | |||
| input2[(block_idx // 6) * 16384 + thread_idx2 * 8192], 0, 1, | |||
| 1024, 0, 0) | |||
| tik_instance.vmul(64, t_1_local_UB, t_1_0_local_UB, input_2_local_UB, 128, 1, 1, 1, 8, 8, 8) | |||
| tik_instance.vadd(64, bisec_last_axis_local_UB, t_1_local_UB, t_1_local_UB[64], 64, 1, 1, 1, | |||
| 16, 16, 16) | |||
| tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB_dst_tmp, 0, 1, 1, 8) | |||
| with tik_instance.for_range(0, 64) as cc6: | |||
| tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB_dst_tmp[cc6], | |||
| bisec_last_axis_local_UB[cc6 * 128], | |||
| 1, 1, 1, 8) | |||
| tik_instance.vadd(64, matmul_hybrid_f_t_local_UB, matmul_hybrid_f_t_local_UB_dst_tmp, | |||
| matmul_hybrid_f_t_local_UB, 1, 1, 1, 1, 8, 8, 8) | |||
| tik_instance.data_move( | |||
| res[(block_idx // 6) * 16384 + (block_idx % 6) * 2816 + cc1_db * 256 + | |||
| thread_idx * 128 + thread_idx2 * 64], | |||
| matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0) | |||
| if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True): | |||
| with tik_instance.for_range(0, 18, block_num=18) as block_idx: | |||
| with tik_instance.for_range(0, 128, thread_num=2) as cc0: | |||
| input1_index = block_idx * 16384 + cc0 * 128 | |||
| input2_index = block_idx * 16384 | |||
| res_index = block_idx * 16384 + cc0 * 128 | |||
| _inner_matmul_new(tik_instance, dtype, | |||
| input1, input1_index, | |||
| input2, input2_index, | |||
| res, res_index) | |||
| if input_shape == ((9, 128, 128), (9, 128, 128), "float32", False, True): | |||
| with tik_instance.for_range(0, 27, block_num=27) as block_idx: | |||
| with tik_instance.for_range(0, 42, thread_num=2) as cc0: | |||
| input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128 | |||
| input2_index = (block_idx // 3) * 16384 | |||
| res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128 | |||
| _inner_matmul_new(tik_instance, dtype, | |||
| input1, input1_index, | |||
| input2, input2_index, | |||
| res, res_index) | |||
| with tik_instance.if_scope((block_idx % 3) < 2): | |||
| input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128 | |||
| input2_index = (block_idx // 3) * 16384 | |||
| res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128 | |||
| _inner_matmul_new(tik_instance, dtype, | |||
| input1, input1_index, | |||
| input2, input2_index, | |||
| res, res_index) | |||
| if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as cc0: | |||
| input1_index = block_idx * 128 + cc0 * 64 | |||
| input2_index = 0 | |||
| res_index = block_idx * 128 + cc0 * 64 | |||
| _inner_matmul_new_1_64_32_64(tik_instance, dtype, | |||
| input1, input1_index, | |||
| input2, input2_index, | |||
| res, res_index) | |||
| input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True), | |||
| ((2, 128, 128), (2, 128, 128), "float32", False, True), | |||
| ((4, 128, 128), (4, 128, 128), "float32", False, True), | |||
| ((8, 128, 128), (8, 128, 128), "float32", False, True), | |||
| ((16, 128, 128), (16, 128, 128), "float32", False, True) | |||
| ] | |||
| if input_shape in input_shape_list: | |||
| block_num = 32 | |||
| input1_unit_size = 128 | |||
| input2_unint_size = 128 * 128 | |||
| with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: | |||
| block_process_ele_num = (batch * m * k) // block_num | |||
| loop_time = (batch * m * k) // block_num // input1_unit_size | |||
| thread_num = 2 | |||
| with tik_instance.for_range(0, loop_time, thread_num=thread_num) as cc0: | |||
| input1_index = block_idx * block_process_ele_num + cc0 * input1_unit_size | |||
| if batch > 1: | |||
| input2_index = block_idx // (block_num // batch) * input2_unint_size | |||
| else: | |||
| input2_index = 0 | |||
| res_index = block_idx * block_process_ele_num + cc0 * input1_unit_size | |||
| _inner_matmul_new(tik_instance, dtype, | |||
| input1, input1_index, | |||
| input2, input2_index, | |||
| res, res_index) | |||
| tik_instance.BuildCCE(kernel_name, inputs=[input1, input2], outputs=[res]) | |||
| return tik_instance | |||
| @@ -0,0 +1,111 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CusCholeskyTrsm""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| from te import tik | |||
| from topi.cce import util | |||
| cus_cholesky_trsm_op_info = TBERegOp("CusCholeskyTrsm") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("choleskytrsm.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusCholeskyTrsm") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(cus_cholesky_trsm_op_info) | |||
| def CusCholeskyTrsm(input_x, output, kernel_name): | |||
| """CusCholeskyTrsm""" | |||
| input_x_shape = input_x.get("shape") | |||
| output_shape = output.get("shape") | |||
| split_dim = 128 | |||
| matrix_dim = input_x_shape[0] | |||
| split_dim = min(matrix_dim, split_dim) | |||
| vector_repeat_times = int(split_dim // 64) | |||
| blocks = int(matrix_dim // split_dim) | |||
| if blocks == 0: | |||
| blocks = 1 | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) | |||
| res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) | |||
| with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: | |||
| input_x_ub = tik_instance.Tensor("float32", (split_dim, split_dim), name="input_x_ub", scope=tik.scope_ubuf) | |||
| temp_ub = tik_instance.Tensor("float32", (split_dim, split_dim), name="temp_ub", scope=tik.scope_ubuf) | |||
| assist_1_ub = tik_instance.Tensor("float32", (split_dim,), name="assist_1_ub", scope=tik.scope_ubuf) | |||
| assist_2_ub = tik_instance.Tensor("float32", (split_dim,), name="assist_2_ub", scope=tik.scope_ubuf) | |||
| with tik_instance.for_range(0, split_dim) as i: | |||
| tik_instance.data_move(input_x_ub[i, 0], input_x[block_index * split_dim + i, block_index * split_dim], 0, | |||
| 1, vector_repeat_times * 8, 0, 0) | |||
| scalar1 = tik_instance.Scalar("float32", init_value=-0.5) | |||
| with tik_instance.for_range(0, split_dim) as i: | |||
| scalar2 = tik_instance.Scalar("float32") | |||
| tik_instance.vln(64, assist_1_ub[0], input_x_ub[i, 0], vector_repeat_times, 1, 1, 8, 8) | |||
| tik_instance.vmuls(64, assist_2_ub[0], assist_1_ub[0], scalar1, vector_repeat_times, 1, 1, 8, 8) | |||
| tik_instance.vexp(64, assist_1_ub[0], assist_2_ub[0], vector_repeat_times, 1, 1, 8, 8) | |||
| scalar2.set_as(assist_1_ub[i]) | |||
| tik_instance.vmuls(64, input_x_ub[i, 0], input_x_ub[i, 0], scalar2, vector_repeat_times, 1, 1, 8, 8) | |||
| with tik_instance.for_range(i + 1, split_dim) as j: | |||
| scalar3 = tik_instance.Scalar("float32") | |||
| scalar3.set_as(input_x_ub[i, j]) | |||
| tik_instance.vmuls(64, temp_ub[j, 0], input_x_ub[i, 0], scalar3, vector_repeat_times, 1, 1, 8, 8) | |||
| tik_instance.vsub(64, input_x_ub[i + 1, 0], input_x_ub[i + 1, 0], temp_ub[i + 1, 0], | |||
| (split_dim - 1 - i) * vector_repeat_times, 1, 1, 1, 8, 8, 8) | |||
| zero = tik_instance.Scalar("float32") | |||
| zero.set_as(0.0) | |||
| one = tik_instance.Scalar("float32") | |||
| one.set_as(1.0) | |||
| with tik_instance.for_range(0, split_dim) as i: | |||
| tik_instance.vector_dup(64, temp_ub[i, 0], zero, vector_repeat_times, 1, 8) | |||
| temp_ub.__setitem__(i * split_dim + i, one) | |||
| chol_diag_element_final = tik_instance.Scalar("float32") | |||
| chol_diag_element_final.set_as(input_x_ub[split_dim * split_dim - 1]) | |||
| trsm_diag_element = tik_instance.Scalar("float32") | |||
| trsm_diag_element.set_as(1.0 / chol_diag_element_final) | |||
| temp_ub.__setitem__(split_dim * split_dim - 1, trsm_diag_element) | |||
| with tik_instance.for_range(1, split_dim) as i: | |||
| index = split_dim - i - 1 | |||
| tik_instance.vector_dup(64, assist_1_ub, zero, vector_repeat_times, 1, 8) | |||
| with tik_instance.for_range(0, i) as j: | |||
| chol_diag_element_loop = tik_instance.Scalar("float32") | |||
| chol_diag_element_loop.set_as(input_x_ub[index, index + 1 + j]) | |||
| tik_instance.vmuls(64, assist_2_ub, temp_ub[j + index + 1, 0], chol_diag_element_loop, | |||
| vector_repeat_times, 1, 1, 8, 8) | |||
| tik_instance.vadd(64, assist_1_ub, assist_2_ub, assist_1_ub, vector_repeat_times, 1, 1, 1, 8, 8, 8) | |||
| temp_scalar = tik_instance.Scalar("float32") | |||
| temp_scalar.set_as(input_x_ub[index, index]) | |||
| chol_diag_element = tik_instance.Scalar("float32") | |||
| chol_diag_element.set_as(1.0 / temp_scalar) | |||
| tik_instance.vsub(64, temp_ub[index, 0], temp_ub[index, 0], assist_1_ub, vector_repeat_times, 1, 1, 1, 8, 8, | |||
| 8) | |||
| tik_instance.vmuls(64, temp_ub[index, 0], temp_ub[index, 0], chol_diag_element, vector_repeat_times, 1, 1, | |||
| 8, 8) | |||
| tik_instance.data_move(res[block_index, 0, 0], temp_ub, 0, 1, 8 * vector_repeat_times * split_dim, 0, 0) | |||
| tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) | |||
| return tik_instance | |||
| @@ -0,0 +1,468 @@ | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| import te.lang.cce | |||
| import te.platform.cce_params as cce | |||
| from te import tik | |||
| from te import tvm | |||
| from topi import generic | |||
| from topi.cce import util | |||
| # General limitation of the size for input shape: 2**31 | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| NoneType = type(None) | |||
| matmul_cube_dense_left_op_info = TBERegOp("CusMatMulCubeDenseLeft") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("matmulcubedenseleft.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusMatMulCubeDenseLeft") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ | |||
| .get_op_info() | |||
| # pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, | |||
| def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): | |||
| """ | |||
| Check the given input if legal | |||
| Parameters: | |||
| shape_a: list or tuple | |||
| Shape of the first tensor a with rank > 1 | |||
| shape_b: list or tuple | |||
| Shape of the second tensor b with the same type with a, | |||
| and shape_a, shape_b must be 2 dims | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| src_dtype: str | |||
| The data type of input, support "float32", "float16" | |||
| trans_a: bool | |||
| If True, shape_a == transposed before multiplication | |||
| trans_b: bool | |||
| If True, shape_b == transposed before multiplication | |||
| Returns None | |||
| """ | |||
| shape_len = len(shape_a) | |||
| src_dtype = src_dtype.lower() | |||
| k_block_size = cce.BLOCK_REDUCE | |||
| check_list = ("float16") | |||
| if src_dtype not in check_list: | |||
| raise RuntimeError("matmul_cce only support %s while src_dtype == %s" | |||
| % (",".join(check_list), src_dtype)) | |||
| if shape_len != len(shape_b): | |||
| raise RuntimeError("length of a and b are not equal") | |||
| if shape_len != 2: | |||
| raise RuntimeError( | |||
| "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") | |||
| is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False | |||
| is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False | |||
| if trans_a: | |||
| m_shape = shape_a[shape_len - 1] | |||
| km_shape = shape_a[shape_len - 2] | |||
| else: | |||
| m_shape = shape_a[shape_len - 2] | |||
| km_shape = shape_a[shape_len - 1] | |||
| if trans_b: | |||
| kn_shape = shape_b[shape_len - 1] | |||
| n_shape = shape_b[shape_len - 2] | |||
| else: | |||
| kn_shape = shape_b[shape_len - 2] | |||
| n_shape = shape_b[shape_len - 1] | |||
| if m_shape == 1: | |||
| if n_shape == 1: | |||
| raise RuntimeError("input shape M and N can't both be 1") | |||
| if km_shape != kn_shape: | |||
| print(km_shape, kn_shape) | |||
| raise RuntimeError("reduce axis not same") | |||
| if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: | |||
| raise RuntimeError( | |||
| "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) | |||
| if m_shape != 1: | |||
| if n_shape == 1: | |||
| if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: | |||
| raise RuntimeError("input shape K1 should be multiple of %d" | |||
| % (cce.BLOCK_IN * cce.BLOCK_IN)) | |||
| elif km_shape % k_block_size != 0: | |||
| raise RuntimeError( | |||
| "input shape K1 should be multiple of %d" % cce.BLOCK_IN) | |||
| else: | |||
| if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: | |||
| raise RuntimeError("input shape K1 should be multiple of %d" | |||
| % (cce.BLOCK_IN * cce.BLOCK_IN)) | |||
| if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: | |||
| raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) | |||
| if len(shape_bias) != 0: | |||
| if len(shape_bias) == 1: | |||
| if is_gevm or is_gemv: | |||
| if shape_bias[0] != m_shape * n_shape: | |||
| raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") | |||
| else: | |||
| if shape_bias[0] != n_shape: | |||
| raise RuntimeError("broadcast bias shape must be equal to shape n") | |||
| elif len(shape_bias) == shape_len: | |||
| if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: | |||
| raise RuntimeError("non broadcast bias shape must be same as output shape") | |||
| else: | |||
| raise RuntimeError("unsupport input shape now for batch bias case") | |||
| def _get_bias(shape_bias): | |||
| """_get_bias""" | |||
| bias_length = shape_bias[0] | |||
| shb = [] | |||
| if bias_length % 16 == 0: | |||
| shb = shape_bias | |||
| else: | |||
| bias_length = (bias_length // 16) * 16 + 16 | |||
| shape_bias = [] | |||
| shape_bias.append(bias_length) | |||
| shb = shape_bias | |||
| return shb | |||
| def _get_input_shape(shape_x): | |||
| """_get_input_shape""" | |||
| dim_a = shape_x[0] | |||
| dim_b = shape_x[1] | |||
| res = [] | |||
| if dim_a % 16 != 0: | |||
| dim_a = (dim_a // 16) * 16 + 16 | |||
| res.append(dim_a) | |||
| else: | |||
| res.append(dim_a) | |||
| if dim_b % 16 != 0: | |||
| dim_b = (dim_b // 16) * 16 + 16 | |||
| res.append(dim_b) | |||
| else: | |||
| res.append(dim_b) | |||
| return res | |||
| def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): | |||
| """check_supported""" | |||
| shape_a = input_x1.get("shape") | |||
| shape_b = input_x2.get("shape") | |||
| print("shape_a: ", shape_a) | |||
| print("shape_b: ", shape_b) | |||
| src_dtype = input_x1.get("dtype") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape_a) | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| try: | |||
| trans_a_f = bool(1 - trans_a) | |||
| if src_dtype == "float32" or src_dtype == "int32": | |||
| if len(shape_a) != 2 and len(shape_b) != 2: | |||
| return False | |||
| if trans_b: | |||
| if shape_b[0] == 1: | |||
| return False | |||
| else: | |||
| if shape_b[1] == 1: | |||
| return False | |||
| if trans_a: | |||
| if trans_b: | |||
| if shape_a[0] != shape_b[1]: | |||
| return False | |||
| elif shape_a[0] != shape_b[0]: | |||
| return False | |||
| elif trans_b: | |||
| if shape_a[1] != shape_b[1]: | |||
| return False | |||
| elif shape_a[1] != shape_b[0]: | |||
| return False | |||
| if trans_a_f and trans_b and shape_b[1] == 1: | |||
| return False | |||
| if src_dtype == "float16": | |||
| if len(shape_a) != 2 and len(shape_b) != 2: | |||
| return False | |||
| if trans_a: | |||
| m_shape = shape_a[1] | |||
| k_shape = shape_a[0] | |||
| else: | |||
| m_shape = shape_a[0] | |||
| k_shape = shape_a[1] | |||
| if trans_b: | |||
| n_shape = shape_b[0] | |||
| k_b_shape = shape_b[1] | |||
| else: | |||
| n_shape = shape_b[1] | |||
| k_b_shape = shape_b[0] | |||
| if k_shape != k_b_shape: | |||
| return False | |||
| if m_shape == 1 or n_shape == 1: | |||
| if k_shape % 256 != 0: | |||
| return False | |||
| except RuntimeError as e: | |||
| return False | |||
| return True | |||
| # pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements | |||
| # @util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) | |||
| @op_info_register(matmul_cube_dense_left_op_info) | |||
| def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="matmulcube"): | |||
| """ | |||
| calculating matrix multiplication with bias, C = A*B + bias, support input | |||
| data with fractal format. | |||
| Parameters: | |||
| shape_a: list or tuple | |||
| Shape of the first tensor a with rank > 1 | |||
| shape_b: list or tuple | |||
| Shape of the second tensor b with the same type with a, | |||
| and shape_a, shape_b must be 2 dims | |||
| src_dtype: str | |||
| The data type of input, support "float32", "float16" | |||
| dst_dtype: str | |||
| The data type of output, support "float32", "float16" | |||
| trans_a: bool | |||
| If True, shape_a == transposed before multiplication | |||
| trans_b: bool | |||
| If True, shape_b == transposed before multiplication | |||
| is_fractal: bool | |||
| If True, the input data format of a and b must be fractal format | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| Returns | |||
| ------- | |||
| None | |||
| """ | |||
| print("!!!!come into zzt~~~~~~~!!!!") | |||
| shape_a = input_x1.get("ori_shape") | |||
| shape_b = input_x2.get("ori_shape") | |||
| shape_output = output_y.get("ori_shape") | |||
| print("============") | |||
| print(input_x1.get("format"), input_x2.get("format")) | |||
| print(shape_a, shape_b) | |||
| print("============") | |||
| if input_x2.get("format") == "FRACTAL_Z": | |||
| n, c, h, w = shape_b | |||
| c0 = 16 | |||
| c1 = c // c0 | |||
| if c1 == 0: | |||
| c1 = 1 | |||
| shape_b = [n, c1 * h * w * c0] | |||
| shape_a = [n, n] | |||
| if input_x1.get("format") == "FRACTAL_Z": | |||
| n, c, h, w = shape_a | |||
| c0 = 16 | |||
| c1 = c // c0 | |||
| if c1 == 0: | |||
| c1 = 1 | |||
| shape_a = [n, c1 * h * w * c0] | |||
| shape_b = [c1 * h * w * c0, c1 * h * w * c0] | |||
| if input_x2.get("format") == "FRACTAL_NZ": | |||
| shape_a = [shape_b[0], shape_b[0]] | |||
| shape_b = shape_b | |||
| if input_x1.get("format") == "FRACTAL_NZ": | |||
| shape_a = shape_a | |||
| shape_b = [shape_a[1], shape_a[1]] | |||
| shape_a = list(shape_a) | |||
| shape_b = list(shape_b) | |||
| shape_a = _get_input_shape(shape_a) | |||
| shape_b = _get_input_shape(shape_b) | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape_a) | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| shape_a = [shape_a[1], shape_a[0]] | |||
| trans_a = bool(1 - trans_a) | |||
| shape_b = [shape_b[1], shape_b[0]] | |||
| trans_b = bool(1 - trans_b) | |||
| shape_bias = () | |||
| if bias is not None and bool(bias): | |||
| shape_bias = bias.get("shape") | |||
| shape_bias = list(shape_bias) | |||
| shape_bias = _get_bias(shape_bias) | |||
| src_dtype = input_x1.get("dtype").lower() | |||
| dst_dtype = output_y.get("dtype").lower() | |||
| _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) | |||
| m_shape = shape_a[len(shape_a) - 2] | |||
| km_shape = shape_a[len(shape_a) - 1] | |||
| kn_shape = shape_b[len(shape_a) - 2] | |||
| n_shape = shape_b[len(shape_a) - 1] | |||
| if src_dtype == "float16": | |||
| block_reduce = cce.BLOCK_REDUCE | |||
| block_in = cce.BLOCK_IN | |||
| block_out = cce.BLOCK_OUT | |||
| if trans_a and km_shape == 1: | |||
| block_in = cce.BLOCK_VECTOR | |||
| if not trans_a and m_shape == 1: | |||
| block_in = cce.BLOCK_VECTOR | |||
| if trans_b and kn_shape == 1: | |||
| block_out = cce.BLOCK_VECTOR | |||
| if not trans_b and n_shape == 1: | |||
| block_out = cce.BLOCK_VECTOR | |||
| if trans_a: | |||
| shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) | |||
| else: | |||
| shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) | |||
| if trans_b: | |||
| shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) | |||
| else: | |||
| shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) | |||
| shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) | |||
| format_a = "FRACTAL_NZ" | |||
| shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) | |||
| format_b = "FRACTAL_NZ" | |||
| print("=======================================") | |||
| print(shape_a_temp, shape_b_temp) | |||
| print(format_a, format_b) | |||
| print("=======================================") | |||
| tensor_bias = None | |||
| tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', | |||
| dtype=src_dtype) | |||
| tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', | |||
| dtype=src_dtype) | |||
| if len(shape_bias) > 0: | |||
| tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias', | |||
| dtype=dst_dtype) | |||
| if shape_a_temp[0] == 63 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 63: | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm) | |||
| input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm) | |||
| resMatmul = tik_instance.Tensor("float16", shape_output, name="output", scope=tik.scope_gm) | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_index: | |||
| resMatmul_local_UB = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_ubuf, | |||
| name="resMatmul_local_UB") | |||
| resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (128 * 256,), scope=tik.scope_cc, | |||
| name="resMatmul_local_UB") | |||
| input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_ca, | |||
| name="input_1_local_L1_local_L0A") | |||
| input_2_local_L1 = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cbuf, | |||
| name="input_2_local_L1") | |||
| input_1_local_L1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cbuf, | |||
| name="input_1_local_L1") | |||
| input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cb, | |||
| name="input_2_local_L1_local_L0B") | |||
| core_m_idx = block_index % 8 | |||
| core_n_idx = block_index // 8 | |||
| with tik_instance.if_scope(core_m_idx != 7): | |||
| tik_instance.data_move(input_1_local_L1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 8, 128, | |||
| 55 * 16, 0) | |||
| tik_instance.data_move(input_2_local_L1, input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], 0, | |||
| 32, 128, 55 * 16, 0) | |||
| with tik_instance.for_range(0, 8) as cc12: | |||
| tik_instance.load2dv1(input_1_local_L1_local_L0A[cc12 * 2048], input_1_local_L1[cc12 * 256], 0, 8, | |||
| 8, 0, False) | |||
| with tik_instance.for_range(0, 2) as cc6: | |||
| with tik_instance.for_range(0, 8) as cc121: | |||
| tik_instance.load2dv1(input_2_local_L1_local_L0B[cc121 * 4096], | |||
| input_2_local_L1[cc6 * 32768 + cc121 * 256], 0, 16, 8, 0, True) | |||
| tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, | |||
| input_2_local_L1_local_L0B, 128, 128, 256, 0) | |||
| tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0, 1) | |||
| tik_instance.data_move(resMatmul[cc6 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], | |||
| resMatmul_local_UB, 0, 16, 256 // 2, 0, 55 * 16 * 2 // 2) | |||
| with tik_instance.else_scope(): | |||
| tik_instance.data_move(input_1_local_L1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 7, 112, | |||
| 56 * 16, 0) | |||
| tik_instance.data_move(input_2_local_L1, input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], 0, | |||
| 32, 112, 56 * 16, 0) | |||
| with tik_instance.for_range(0, 7) as cc10: | |||
| tik_instance.load2dv1(input_1_local_L1_local_L0A[cc10 * 1792], input_1_local_L1[cc10 * 256], 0, 7, | |||
| 7, 0, False) | |||
| with tik_instance.for_range(0, 2) as cc5: | |||
| with tik_instance.for_range(0, 7) as cc101: | |||
| tik_instance.load2dv1(input_2_local_L1_local_L0B[cc101 * 4096], | |||
| input_2_local_L1[cc5 * 28672 + cc101 * 256], 0, 16, 7, 0, True) | |||
| tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, | |||
| input_2_local_L1_local_L0B, 112, 112, 256, 0) | |||
| tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 112, 0, 0, 1) | |||
| tik_instance.data_move(resMatmul[cc5 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], | |||
| resMatmul_local_UB, 0, 16, 224 // 2, 0, 56 * 16 * 2 // 2) | |||
| tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[resMatmul]) | |||
| return tik_instance | |||
| else: | |||
| print("come into tbe, shape is error!") | |||
| result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a, | |||
| format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias) | |||
| with tvm.target.cce(): | |||
| schedule = generic.auto_schedule(result) | |||
| tensor_list = [tensor_a, tensor_b, result] | |||
| if len(shape_bias) > 0: | |||
| tensor_list = [tensor_a, tensor_b, tensor_bias, result] | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(schedule, config) | |||
| @@ -0,0 +1,172 @@ | |||
| #!/usr/bin/env python | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| from te import tik | |||
| from topi.cce import util | |||
| matmul_cube_dense_right_op_info = TBERegOp("CusMatMulCubeDenseRight") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("matmulcubedenseright.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusMatMulCubeDenseRight") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "required", "all") \ | |||
| .input(3, "x4", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_Default, DataType.F16_Default, | |||
| DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register(matmul_cube_dense_right_op_info) | |||
| def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="matmulcube"): | |||
| """CusMatMulCubeDenseRight""" | |||
| shape_a_temp = (128, 63, 16, 16) | |||
| shape_b_temp = (128, 128, 16, 16) | |||
| shape_output = output_y.get("shape") | |||
| matrix_max_shape = (1,) | |||
| support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape),] | |||
| shape_a_input = input_x1.get("shape") | |||
| shape_b_input = input_x2.get("shape") | |||
| matrix_max_input = input_x3.get("shape") | |||
| input_shape = (tuple(shape_a_input), tuple(shape_b_input), tuple(matrix_max_input)) | |||
| if input_shape not in support_shape: | |||
| raise RuntimeError("input_shape %s is not supported" % str(input_shape)) | |||
| if shape_a_temp[0] == 128 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 128: | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm) | |||
| input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm) | |||
| input_x3 = tik_instance.Tensor("float32", [1,], name="matrix_max", scope=tik.scope_gm) | |||
| resMatmul = tik_instance.Tensor("float32", shape_output, name="output", scope=tik.scope_gm) | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_index: | |||
| core_m_idx = block_index // 16 | |||
| core_n_idx = block_index % 16 | |||
| matrix_max_scalar = tik_instance.Scalar("float32") | |||
| matrix_max_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="matrix_max_local_UB") | |||
| tik_instance.data_move(matrix_max_local_UB, input_x3, 0, 1, 1, 0, 0) | |||
| matrix_max_scalar.set_as(matrix_max_local_UB[0]) | |||
| resMatmul_local_UB = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf, | |||
| name="resMatmul_local_UB") | |||
| resMatmul_local_UB1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_ubuf, | |||
| name="resMatmul_local_UB1") | |||
| resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc, | |||
| name="resMatmul_local_UB_local_L0C") | |||
| resMatmul_local_UB_local_L0C1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_cc, | |||
| name="resMatmul_local_UB_local_L0C1") | |||
| input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca, | |||
| name="input_1_local_L1_local_L0A") | |||
| input_2_local_L1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf, | |||
| name="input_2_local_L1") | |||
| input_2_local_L11 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf, | |||
| name="input_2_local_L11") | |||
| input_1_local_L1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf, | |||
| name="input_1_local_L1") | |||
| input_1_local_L11 = tik_instance.Tensor("float16", (8 * 240 * 16,), scope=tik.scope_cbuf, | |||
| name="input_1_local_L11") | |||
| input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb, | |||
| name="input_2_local_L1_local_L0B") | |||
| input_2_local_L1_local_L0B1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb, | |||
| name="input_2_local_L1_local_L0B1") | |||
| with tik_instance.if_scope(core_m_idx == 0): | |||
| with tik_instance.for_range(0, 2) as cc1: | |||
| tik_instance.data_move(input_2_local_L1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, | |||
| 128, 1920, 0) | |||
| tik_instance.data_move(input_1_local_L1, input_x1[core_n_idx * 129024 + cc1 * 4096], 0, 8, 256, 752, | |||
| 0) | |||
| with tik_instance.for_range(0, 8) as cc10: | |||
| tik_instance.load2dv1(input_2_local_L1_local_L0B[cc10 * 2048], input_2_local_L1[cc10 * 256], 0, | |||
| 8, 8, 0, True) | |||
| with tik_instance.for_range(0, 16) as cc101: | |||
| tik_instance.load2dv1(input_1_local_L1_local_L0A[cc101 * 2048], input_1_local_L1[cc101 * 256], | |||
| 0, 8, 16, 0, False) | |||
| tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, | |||
| input_2_local_L1_local_L0B, 256, 128, 128, 0) | |||
| tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0) | |||
| tik_instance.vmuls(64, resMatmul_local_UB, resMatmul_local_UB, matrix_max_scalar, 255, 1, 1, 8, 8) | |||
| tik_instance.vmuls(64, resMatmul_local_UB[255 * 64], resMatmul_local_UB[255 * 64], | |||
| matrix_max_scalar, 255, 1, 1, 8, 8) | |||
| tik_instance.vmuls(64, resMatmul_local_UB[510 * 64], resMatmul_local_UB[510 * 64], | |||
| matrix_max_scalar, 2, 1, 1, 8, 8) | |||
| tik_instance.data_move(resMatmul[core_n_idx * 129024 + cc1 * 4096], resMatmul_local_UB, 0, 8, 512, | |||
| 0, 1504) | |||
| with tik_instance.else_scope(): | |||
| tik_instance.data_move(input_2_local_L1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128, | |||
| 1920, 0) | |||
| tik_instance.data_move(input_1_local_L1, input_x1[core_n_idx * 129024 + 2 * 4096], 0, 8, 256, 752, 0) | |||
| with tik_instance.for_range(0, 8) as cc10: | |||
| tik_instance.load2dv1(input_2_local_L1_local_L0B[cc10 * 2048], input_2_local_L1[cc10 * 256], 0, 8, | |||
| 8, 0, True) | |||
| with tik_instance.for_range(0, 16) as cc101: | |||
| tik_instance.load2dv1(input_1_local_L1_local_L0A[cc101 * 2048], input_1_local_L1[cc101 * 256], 0, 8, | |||
| 16, 0, False) | |||
| tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, input_2_local_L1_local_L0B, | |||
| 256, 128, 128, 0) | |||
| tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0) | |||
| tik_instance.vmuls(64, resMatmul_local_UB, resMatmul_local_UB, matrix_max_scalar, 255, 1, 1, 8, 8) | |||
| tik_instance.vmuls(64, resMatmul_local_UB[255 * 64], resMatmul_local_UB[255 * 64], matrix_max_scalar, | |||
| 255, 1, 1, 8, 8) | |||
| tik_instance.vmuls(64, resMatmul_local_UB[510 * 64], resMatmul_local_UB[510 * 64], matrix_max_scalar, 2, | |||
| 1, 1, 8, 8) | |||
| tik_instance.data_move(resMatmul[core_n_idx * 129024 + 2 * 4096], resMatmul_local_UB, 0, 8, 512, 0, | |||
| 1504) | |||
| tik_instance.data_move(input_2_local_L11, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128, | |||
| 1920, 0) | |||
| tik_instance.data_move(input_1_local_L11, input_x1[core_n_idx * 129024 + 12288], 0, 8, 240, 768, 0) | |||
| with tik_instance.for_range(0, 8) as cc102: | |||
| tik_instance.load2dv1(input_2_local_L1_local_L0B1[cc102 * 2048], input_2_local_L11[cc102 * 256], 0, | |||
| 8, 8, 0, True) | |||
| with tik_instance.for_range(0, 16) as cc103: | |||
| tik_instance.load2dv1(input_1_local_L1_local_L0A[cc103 * 2048], input_1_local_L11[cc103 * 256], 0, | |||
| 8, 15, 0, False) | |||
| tik_instance.mmad(resMatmul_local_UB_local_L0C1, input_1_local_L1_local_L0A, | |||
| input_2_local_L1_local_L0B1, 240, 128, 128, 0) | |||
| tik_instance.data_move(resMatmul_local_UB1, resMatmul_local_UB_local_L0C1, 0, 1, 120, 0, 0) | |||
| tik_instance.vmuls(64, resMatmul_local_UB1, resMatmul_local_UB1, matrix_max_scalar, 255, 1, 1, 8, 8) | |||
| tik_instance.vmuls(64, resMatmul_local_UB1[255 * 64], resMatmul_local_UB1[255 * 64], matrix_max_scalar, | |||
| 225, 1, 1, 8, 8) | |||
| tik_instance.data_move(resMatmul[core_n_idx * 129024 + 12288], resMatmul_local_UB1, 0, 8, 480, 0, 1536) | |||
| tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul]) | |||
| return tik_instance | |||
| @@ -0,0 +1,526 @@ | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| import te.platform.cce_params as cce | |||
| from te import tik | |||
| from topi.cce import util | |||
| # General limitation of the size for input shape: 2**31 | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| NoneType = type(None) | |||
| matmul_cube_fracz_left_cast_op_info = TBERegOp("CusMatMulCubeFraczLeftCast") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("matmulcubefraczleftcast.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusMatMulCubeFraczLeftCast") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F32_FracZ, DataType.F16_Default, DataType.F16_FracZ) \ | |||
| .get_op_info() | |||
| # pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, | |||
| def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): | |||
| """ | |||
| Check the given input if legal | |||
| Parameters: | |||
| shape_a: list or tuple | |||
| Shape of the first tensor a with rank > 1 | |||
| shape_b: list or tuple | |||
| Shape of the second tensor b with the same type with a, | |||
| and shape_a, shape_b must be 2 dims | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| src_dtype: str | |||
| The data type of input, support "float32", "float16" | |||
| trans_a: bool | |||
| If True, shape_a == transposed before multiplication | |||
| trans_b: bool | |||
| If True, shape_b == transposed before multiplication | |||
| Returns None | |||
| """ | |||
| shape_len = len(shape_a) | |||
| src_dtype = src_dtype.lower() | |||
| k_block_size = cce.BLOCK_REDUCE | |||
| check_list = ("float16") | |||
| if src_dtype not in check_list: | |||
| raise RuntimeError("matmul_cce only support %s while src_dtype == %s" | |||
| % (",".join(check_list), src_dtype)) | |||
| if shape_len != len(shape_b): | |||
| raise RuntimeError("length of a and b are not equal") | |||
| if shape_len != 2: | |||
| raise RuntimeError( | |||
| "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") | |||
| is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False | |||
| is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False | |||
| if trans_a: | |||
| m_shape = shape_a[shape_len - 1] | |||
| km_shape = shape_a[shape_len - 2] | |||
| else: | |||
| m_shape = shape_a[shape_len - 2] | |||
| km_shape = shape_a[shape_len - 1] | |||
| if trans_b: | |||
| kn_shape = shape_b[shape_len - 1] | |||
| n_shape = shape_b[shape_len - 2] | |||
| else: | |||
| kn_shape = shape_b[shape_len - 2] | |||
| n_shape = shape_b[shape_len - 1] | |||
| if m_shape == 1: | |||
| if n_shape == 1: | |||
| raise RuntimeError("input shape M and N can't both be 1") | |||
| if km_shape != kn_shape: | |||
| print(km_shape, kn_shape) | |||
| raise RuntimeError("reduce axis not same") | |||
| if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: | |||
| raise RuntimeError( | |||
| "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) | |||
| if m_shape != 1: | |||
| if n_shape == 1: | |||
| if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: | |||
| raise RuntimeError("input shape K1 should be multiple of %d" | |||
| % (cce.BLOCK_IN * cce.BLOCK_IN)) | |||
| elif km_shape % k_block_size != 0: | |||
| raise RuntimeError( | |||
| "input shape K1 should be multiple of %d" % cce.BLOCK_IN) | |||
| else: | |||
| if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: | |||
| raise RuntimeError("input shape K1 should be multiple of %d" | |||
| % (cce.BLOCK_IN * cce.BLOCK_IN)) | |||
| if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: | |||
| raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) | |||
| if len(shape_bias): | |||
| if len(shape_bias) == 1: | |||
| if is_gevm or is_gemv: | |||
| if shape_bias[0] != m_shape * n_shape: | |||
| raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") | |||
| else: | |||
| if shape_bias[0] != n_shape: | |||
| raise RuntimeError("broadcast bias shape must be equal to shape n") | |||
| elif len(shape_bias) == shape_len: | |||
| if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: | |||
| raise RuntimeError("non broadcast bias shape must be same as output shape") | |||
| else: | |||
| raise RuntimeError("unsupport input shape now for batch bias case") | |||
| def _get_bias(shape_bias): | |||
| """_get_bias""" | |||
| bias_length = shape_bias[0] | |||
| if bias_length % 16 == 0: | |||
| return shape_bias | |||
| else: | |||
| bias_length = (bias_length // 16) * 16 + 16 | |||
| shape_bias = [] | |||
| shape_bias.append(bias_length) | |||
| return shape_bias | |||
| def _get_input_shape(shape_x): | |||
| """_get_input_shape""" | |||
| dim_a = shape_x[0] | |||
| dim_b = shape_x[1] | |||
| res = [] | |||
| if dim_a % 16 != 0: | |||
| dim_a = (dim_a // 16) * 16 + 16 | |||
| res.append(dim_a) | |||
| else: | |||
| res.append(dim_a) | |||
| if dim_b % 16 != 0: | |||
| dim_b = (dim_b // 16) * 16 + 16 | |||
| res.append(dim_b) | |||
| else: | |||
| res.append(dim_b) | |||
| return res | |||
| def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): | |||
| """check_supported""" | |||
| shape_a = input_x1.get("shape") | |||
| shape_b = input_x2.get("shape") | |||
| print("shape_a: ", shape_a) | |||
| print("shape_b: ", shape_b) | |||
| src_dtype = input_x1.get("dtype") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape_a) | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| try: | |||
| trans_a_f = bool(1 - trans_a) | |||
| if src_dtype == "float32" or src_dtype == "int32": | |||
| if len(shape_a) != 2 and len(shape_b) != 2: | |||
| return False | |||
| if trans_b: | |||
| if shape_b[0] == 1: | |||
| return False | |||
| else: | |||
| if shape_b[1] == 1: | |||
| return False | |||
| if trans_a: | |||
| if trans_b: | |||
| if shape_a[0] != shape_b[1]: | |||
| return False | |||
| elif shape_a[0] != shape_b[0]: | |||
| return False | |||
| elif trans_b: | |||
| if shape_a[1] != shape_b[1]: | |||
| return False | |||
| elif shape_a[1] != shape_b[0]: | |||
| return False | |||
| if trans_a_f and trans_b and shape_b[1] == 1: | |||
| return False | |||
| if src_dtype == "float16": | |||
| if len(shape_a) != 2 and len(shape_b) != 2: | |||
| return False | |||
| if trans_a: | |||
| m_shape = shape_a[1] | |||
| k_shape = shape_a[0] | |||
| else: | |||
| m_shape = shape_a[0] | |||
| k_shape = shape_a[1] | |||
| if trans_b: | |||
| n_shape = shape_b[0] | |||
| k_b_shape = shape_b[1] | |||
| else: | |||
| n_shape = shape_b[1] | |||
| k_b_shape = shape_b[0] | |||
| if k_shape != k_b_shape: | |||
| return False | |||
| if m_shape == 1 or n_shape == 1: | |||
| if k_shape % 256 != 0: | |||
| return False | |||
| except RuntimeError as e: | |||
| return False | |||
| return True | |||
| # pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements | |||
| @op_info_register(matmul_cube_fracz_left_cast_op_info) | |||
| def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="CusMatMulCubeFraczLeftCast"): | |||
| """ | |||
| calculating matrix multiplication with bias, C = A*B + bias, support input | |||
| data with fractal format. | |||
| Parameters: | |||
| shape_a: list or tuple | |||
| Shape of the first tensor a with rank > 1 | |||
| shape_b: list or tuple | |||
| Shape of the second tensor b with the same type with a, | |||
| and shape_a, shape_b must be 2 dims | |||
| src_dtype: str | |||
| The data type of input, support "float32", "float16" | |||
| dst_dtype: str | |||
| The data type of output, support "float32", "float16" | |||
| trans_a: bool | |||
| If True, shape_a == transposed before multiplication | |||
| trans_b: bool | |||
| If True, shape_b == transposed before multiplication | |||
| is_fractal: bool | |||
| If True, the input data format of a and b must be fractal format | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| Returns | |||
| ------- | |||
| None | |||
| """ | |||
| shape_a = input_x1.get("ori_shape") | |||
| shape_b = input_x2.get("ori_shape") | |||
| print("============") | |||
| print(input_x1.get("format"), input_x2.get("format")) | |||
| print(shape_a, shape_b) | |||
| print("============") | |||
| if input_x2.get("format") == "FRACTAL_Z": | |||
| n, c, h, w = shape_b | |||
| c0 = 16 | |||
| c1 = c // c0 | |||
| if c1 == 0: | |||
| c1 = 1 | |||
| shape_b = [n, c1 * h * w * c0] | |||
| shape_a = [n, n] | |||
| if input_x1.get("format") == "FRACTAL_Z": | |||
| n, c, h, w = shape_a | |||
| c0 = 16 | |||
| c1 = c // c0 | |||
| if c1 == 0: | |||
| c1 = 1 | |||
| shape_a = [n, c1 * h * w * c0] | |||
| shape_b = [c1 * h * w * c0, c1 * h * w * c0] | |||
| if input_x2.get("format") == "FRACTAL_NZ": | |||
| shape_a = [shape_b[0], shape_b[0]] | |||
| shape_b = shape_b | |||
| if input_x1.get("format") == "FRACTAL_NZ": | |||
| shape_a = shape_a | |||
| shape_b = [shape_a[1], shape_a[1]] | |||
| shape_a = list(shape_a) | |||
| shape_b = list(shape_b) | |||
| shape_a = _get_input_shape(shape_a) | |||
| shape_b = _get_input_shape(shape_b) | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape_a) | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| shape_a = [shape_a[1], shape_a[0]] | |||
| trans_a = bool(1 - trans_a) | |||
| shape_b = [shape_b[1], shape_b[0]] | |||
| trans_b = bool(1 - trans_b) | |||
| shape_bias = () | |||
| if bias is not None and bool(bias): | |||
| shape_bias = bias.get("shape") | |||
| shape_bias = list(shape_bias) | |||
| shape_bias = _get_bias(shape_bias) | |||
| src_dtype = input_x1.get("dtype").lower() | |||
| _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) | |||
| m_shape = shape_a[len(shape_a) - 2] | |||
| km_shape = shape_a[len(shape_a) - 1] | |||
| kn_shape = shape_b[len(shape_a) - 2] | |||
| n_shape = shape_b[len(shape_a) - 1] | |||
| if src_dtype == "float16": | |||
| block_reduce = cce.BLOCK_REDUCE | |||
| block_in = cce.BLOCK_IN | |||
| block_out = cce.BLOCK_OUT | |||
| if trans_a and km_shape == 1: | |||
| block_in = cce.BLOCK_VECTOR | |||
| if not trans_a and m_shape == 1: | |||
| block_in = cce.BLOCK_VECTOR | |||
| if trans_b and kn_shape == 1: | |||
| block_out = cce.BLOCK_VECTOR | |||
| if not trans_b and n_shape == 1: | |||
| block_out = cce.BLOCK_VECTOR | |||
| if trans_a: | |||
| shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) | |||
| else: | |||
| shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) | |||
| if trans_b: | |||
| shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) | |||
| else: | |||
| shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) | |||
| shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) | |||
| shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| input_x1 = tik_instance.Tensor(input_x1.get("dtype"), shape_a_temp, name="left_matrix", scope=tik.scope_gm) | |||
| input_x2 = tik_instance.Tensor(input_x2.get("dtype"), shape_b_temp, name="right_matrix", scope=tik.scope_gm) | |||
| res_matmul = tik_instance.Tensor(output_y.get("dtype"), output_y.get("shape"), name="output", scope=tik.scope_gm) | |||
| DIAG_SIZE = 128 | |||
| mo_tile, ko_tile, no_tile, diag_opt = get_cus_tile_info(input_x1, input_x2, DIAG_SIZE) | |||
| cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res_matmul, | |||
| mo_tile=mo_tile, ko_tile=ko_tile, no_tile=no_tile, | |||
| diag_opt=diag_opt, diag_size=DIAG_SIZE) | |||
| tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[res_matmul]) | |||
| return tik_instance | |||
| def get_cus_tile_info(input_x1, input_x2, diag_size): | |||
| """get_cus_tile_info""" | |||
| tile_map = { | |||
| ((32, 32, 16, 16), (128, 32, 16, 16)): (8, 8, 16), | |||
| ((8, 8, 16, 16), (72, 8, 16, 16)): (8, 8, 4), | |||
| ((32, 32, 16, 16), (288, 32, 16, 16)): (8, 8, 12), | |||
| ((128, 128, 16, 16), (32, 128, 16, 16)): (8, 8, 16), | |||
| ((16, 16, 16, 16), (144, 16, 16, 16)): (8, 8, 9), | |||
| ((64, 64, 16, 16), (16, 64, 16, 16)): (8, 8, 4), | |||
| ((16, 16, 16, 16), (64, 16, 16, 16)): (8, 8, 4), | |||
| ((32, 32, 16, 16), (8, 32, 16, 16)): (8, 8, 1), | |||
| ((128, 128, 16, 16), (64, 128, 16, 16)): (8, 8, 16), | |||
| ((16, 16, 16, 16), (4, 16, 16, 16)): (8, 8, 1), | |||
| ((16, 16, 16, 16), (32, 16, 16, 16)): (8, 8, 2), | |||
| ((64, 64, 16, 16), (32, 64, 16, 16)): (8, 8, 8), | |||
| ((32, 32, 16, 16), (64, 32, 16, 16)): (8, 8, 8), | |||
| ((32, 32, 16, 16), (16, 32, 16, 16)): (8, 8, 2), | |||
| ((8, 8, 16, 16), (32, 8, 16, 16)): (8, 8, 1), | |||
| ((8, 8, 16, 16), (16, 8, 16, 16)): (4, 8, 1), | |||
| ((4, 4, 16, 16), (16, 4, 16, 16)): (2, 4, 1), | |||
| ((4, 4, 16, 16), (4, 4, 16, 16)): (1, 4, 1), | |||
| ((4, 4, 16, 16), (36, 4, 16, 16)): (2, 4, 3), | |||
| ((4, 4, 16, 16), (49, 4, 16, 16)): (1, 4, 7) | |||
| } | |||
| shape_info = (tuple(input_x1.shape), tuple(input_x2.shape)) | |||
| diag_opt = False | |||
| if input_x1.shape[0] * input_x1.shape[3] > diag_size: | |||
| diag_opt = True | |||
| if shape_info not in tile_map: | |||
| raise ValueError("shape %s is not supported" % str(shape_info)) | |||
| mo_tile, ko_tile, no_tile = tile_map[shape_info] | |||
| return mo_tile, ko_tile, no_tile, diag_opt | |||
| def cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, | |||
| res, mo_tile, ko_tile, no_tile, diag_opt=False, diag_size=128): | |||
| """cus_cube_matmul_cast""" | |||
| ko, mo, _, _ = input_x1.shape | |||
| no, ko, _, _ = input_x2.shape | |||
| c0 = input_x1.shape[-1] | |||
| diag_outer = diag_size // c0 | |||
| maxblocknum = 32 | |||
| fp32_size = 4 | |||
| fp16_size = 2 | |||
| blocksize = 32 | |||
| vectorfp32_size = 64 | |||
| if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]: | |||
| raise ValueError("shape of input_x1 or input_x2 is not supported!") | |||
| if not trans_a or not trans_b: | |||
| raise ValueError("only trans_a=False and trans_b=False be supported!") | |||
| core_m_num = mo // mo_tile | |||
| loop_n_num = no // no_tile | |||
| if loop_n_num * core_m_num <= maxblocknum: | |||
| core_n_num = loop_n_num | |||
| else: | |||
| core_n_num = maxblocknum // core_m_num | |||
| if core_n_num > 0 and loop_n_num % core_n_num == 0: | |||
| loop_n_num = loop_n_num // core_n_num | |||
| else: | |||
| raise ValueError("Does not support this scenario!") | |||
| block_num = core_m_num * core_n_num | |||
| loop_k_num = ko // ko_tile | |||
| if diag_opt: | |||
| loop_k_num = diag_outer // ko_tile | |||
| # double buffer: | |||
| thread_num_k = 2 | |||
| loop_k_num *= thread_num_k | |||
| ko_tile_inner = ko_tile // thread_num_k | |||
| with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: | |||
| core_m = block_idx // core_n_num | |||
| core_n = block_idx % core_n_num | |||
| with tik_instance.for_range(0, loop_n_num) as cc_n: | |||
| res_L0C = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], | |||
| name="resMatmul_L0C", scope=tik.scope_cc) | |||
| with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k: | |||
| # input_x2 -> input_x2_ub -(fp322fp16)-> input_x2_cast_ub -> input_x2_L1 | |||
| input_x2_ub = tik_instance.Tensor("float32", [no_tile, ko_tile_inner, c0, c0], name="input_x2_ub", | |||
| scope=tik.scope_ubuf) | |||
| if diag_opt: | |||
| k_idx = core_m * mo_tile + thread_idx_k * ko_tile_inner | |||
| else: | |||
| k_idx = thread_idx_k * ko_tile_inner | |||
| tik_instance.data_move(input_x2_ub, | |||
| input_x2[(core_n * loop_n_num + cc_n) * no_tile, | |||
| k_idx, 0, 0], | |||
| 0, no_tile, ko_tile_inner * c0 * c0 * fp32_size // blocksize, | |||
| (ko - ko_tile_inner) * c0 * c0 * fp32_size // blocksize, 0) | |||
| input_x2_cast_ub = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], | |||
| name="input_x2_cast_ub", scope=tik.scope_ubuf) | |||
| repeate_num = no_tile * ko_tile_inner * c0 * c0 // vectorfp32_size | |||
| repeate_times_max = 255 | |||
| count = 0 | |||
| while repeate_num > repeate_times_max: | |||
| tik_instance.vconv(vectorfp32_size, 'none', | |||
| input_x2_cast_ub[count * repeate_times_max * vectorfp32_size], | |||
| input_x2_ub[count * repeate_times_max * vectorfp32_size], | |||
| repeate_times_max, | |||
| 1, 1, 4, 8) | |||
| repeate_num -= repeate_times_max | |||
| count += 1 | |||
| tik_instance.vconv(vectorfp32_size, 'none', | |||
| input_x2_cast_ub[count * repeate_times_max * vectorfp32_size], | |||
| input_x2_ub[count * repeate_times_max * vectorfp32_size], repeate_num, | |||
| 1, 1, 4, 8) | |||
| input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], | |||
| name="input_x2_L1", scope=tik.scope_cbuf) | |||
| tik_instance.data_move(input_x2_L1, input_x2_cast_ub, 0, 1, | |||
| no_tile * ko_tile_inner * c0 * c0 * fp16_size // blocksize, 0, 0) | |||
| # input_x1 -> input_x1_L1 | |||
| input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0], | |||
| name="input_x1_L1", scope=tik.scope_cbuf) | |||
| tik_instance.data_move(input_x1_L1, | |||
| input_x1[k_idx, | |||
| core_m * mo_tile, 0, 0], | |||
| 0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize, | |||
| (mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0) | |||
| # input_x2_L1 -> input_x2_L0B | |||
| input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0], | |||
| name="input_x2_L0B", scope=tik.scope_cb) | |||
| with tik_instance.for_range(0, ko_tile_inner) as cc2: | |||
| tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile, | |||
| ko_tile_inner, | |||
| 0, True) | |||
| # input_x1_L1 -> input_x1_L0A | |||
| input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0], | |||
| name="input_x1_L0A", scope=tik.scope_ca) | |||
| with tik_instance.for_range(0, mo_tile) as cc1: | |||
| tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner, | |||
| mo_tile, 0, False) | |||
| with tik_instance.if_scope(thread_idx_k == 0): | |||
| tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, | |||
| ko_tile_inner * c0, no_tile * c0, 0) | |||
| with tik_instance.else_scope(): | |||
| tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, | |||
| ko_tile_inner * c0, no_tile * c0, 1) | |||
| res_ub = tik_instance.Tensor(input_x1.dtype, [no_tile, mo_tile, c0, c0], | |||
| name="resMatmul_ub", scope=tik.scope_ubuf) | |||
| tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0, 1) | |||
| tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, core_m * mo_tile, 0, 0], | |||
| res_ub, 0, no_tile, | |||
| mo_tile * c0 * c0 * fp16_size // blocksize, 0, | |||
| (mo - mo_tile) * c0 * c0 * fp16_size // blocksize) | |||
| @@ -0,0 +1,247 @@ | |||
| #!/usr/bin/env python | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| from te import tik | |||
| from topi.cce import util | |||
| # General limitation of the size for input shape: 2**31 | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| NoneType = type(None) | |||
| cus_matmul_cube_fracz_right_mul_op_info = TBERegOp("CusMatMulCubeFraczRightMul") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("matmulcubefraczrightmul.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusMatMulCubeFraczRightMul") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "required", "all") \ | |||
| .input(3, "x4", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_Default, DataType.F32_Default, DataType.F16_Default, | |||
| DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register(cus_matmul_cube_fracz_right_mul_op_info) | |||
| def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="matmulcube"): | |||
| """CusMatMulCubeFraczRightMul""" | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| input_x1_shape = input_x1.get("shape") | |||
| input_x1_dtype = input_x1.get("dtype").lower() | |||
| input_x2_shape = input_x2.get("shape") | |||
| input_x2_dtype = input_x2.get("dtype").lower() | |||
| input_x3_shape = input_x3.get("shape") | |||
| input_x3_dtype = input_x3.get("dtype").lower() | |||
| output_shape = output_y.get("shape") | |||
| Supported = [((72, 8, 16, 16), "float16", (72, 72, 16, 16), "float16", (1,), "float32"), | |||
| ((32, 8, 16, 16), "float16", (32, 32, 16, 16), "float16", (1,), "float32"), | |||
| ((8, 32, 16, 16), "float16", (8, 8, 16, 16), "float16", (1,), "float32"), | |||
| ((4, 4, 16, 16), "float16", (4, 4, 16, 16), "float16", (1,), "float32"), | |||
| ((4, 16, 16, 16), 'float16', (4, 4, 16, 16), 'float16', (1,), 'float32'), | |||
| ((49, 4, 16, 16), 'float16', (49, 49, 16, 16), 'float16', (1,), 'float32'), | |||
| ((36, 4, 16, 16), 'float16', (36, 36, 16, 16), 'float16', (1,), 'float32'), | |||
| ((64, 16, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), | |||
| ((32, 64, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), | |||
| ((32, 16, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), | |||
| ((16, 32, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), | |||
| ((16, 8, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), | |||
| ((16, 4, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), | |||
| ((288, 32, 16, 16), 'float16', (288, 288, 16, 16), 'float16', (1,), 'float32'), | |||
| ((144, 16, 16, 16), 'float16', (144, 144, 16, 16), 'float16', (1,), 'float32'), | |||
| ((128, 32, 16, 16), 'float16', (128, 128, 16, 16), 'float16', (1,), 'float32'), | |||
| ((64, 128, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), | |||
| ((32, 128, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), | |||
| ((64, 32, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), | |||
| ((16, 64, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32')] | |||
| input_shape = ( | |||
| tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype) | |||
| if input_shape not in Supported: | |||
| raise RuntimeError("input_shape %s is not supported" % str(input_shape)) | |||
| input_x1 = tik_instance.Tensor("float16", input_x1_shape, name="left_matrix", scope=tik.scope_gm) | |||
| input_x2 = tik_instance.Tensor("float16", input_x2_shape, name="right_matrix", scope=tik.scope_gm) | |||
| input_x3 = tik_instance.Tensor("float32", input_x3_shape, name="matrix_max", scope=tik.scope_gm) | |||
| resMatmul = tik_instance.Tensor("float32", output_shape, name="output", scope=tik.scope_gm) | |||
| cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, resMatmul) | |||
| tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul]) | |||
| return tik_instance | |||
| def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, | |||
| res): | |||
| """cus_cube_matmul_right_mul""" | |||
| diag_size = 128 | |||
| ko, mo, _, _ = input_x1.shape | |||
| no, ko, _, _ = input_x2.shape | |||
| c0 = input_x1.shape[-1] | |||
| diag_outer = diag_size // c0 | |||
| if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]: | |||
| raise ValueError("shape of input_x1 or input_x2 is not supported!") | |||
| def get_cus_tile_info(input_x1, input_x2, input_x3): | |||
| """get_cus_tile_info""" | |||
| input_shape = (tuple(input_x1.shape), input_x1.dtype, tuple(input_x2.shape), input_x2.dtype, | |||
| tuple(input_x3.shape), input_x3.dtype) | |||
| tile_map = { | |||
| # no diag opt: | |||
| ((8, 32, 16, 16), "float16", (8, 8, 16, 16), "float16", (1,), "float32"): (4, 8, 2, 8, 4), | |||
| ((4, 4, 16, 16), "float16", (4, 4, 16, 16), "float16", (1,), "float32"): (1, 4, 1, 4, 4), | |||
| ((4, 16, 16, 16), 'float16', (4, 4, 16, 16), 'float16', (1,), 'float32'): (1, 4, 2, 16, 2), | |||
| ((49, 4, 16, 16), 'float16', (49, 49, 16, 16), 'float16', (1,), 'float32'): (1, 7, 7, 4, 7), | |||
| ((36, 4, 16, 16), 'float16', (36, 36, 16, 16), 'float16', (1,), 'float32'): (2, 6, 3, 2, 12), | |||
| # diag opt: | |||
| ((288, 32, 16, 16), 'float16', (288, 288, 16, 16), 'float16', (1,), 'float32'): (16, 8, 8, 2, 12), | |||
| } | |||
| maxblocknum = 32 | |||
| diag_opt = False | |||
| if input_x2.shape[0] * input_x2.shape[3] > diag_size and input_x2.shape[0] % diag_outer == 0: | |||
| diag_opt = True | |||
| if input_shape in tile_map: | |||
| mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_ = tile_map[input_shape] | |||
| elif diag_opt: | |||
| ko_tile_ = diag_outer | |||
| no_tile_ = ko_tile_ | |||
| core_n_num_ = no // no_tile_ | |||
| core_m_num_max = maxblocknum // core_n_num_ | |||
| mo_tile_ = -1 | |||
| core_m_num_ = -1 | |||
| for i in range(core_m_num_max, 0, -1): | |||
| if mo % i == 0: | |||
| core_m_num_ = i | |||
| mo_tile_ = mo // i | |||
| break | |||
| if mo_tile_ == -1: | |||
| raise ValueError("no valid tile be found!") | |||
| while mo_tile_ > 16: | |||
| mo_tile_ = mo_tile_ // 2 | |||
| else: | |||
| raise ValueError("please add tile config to the tile_map") | |||
| print("shape: %s, tile: %s" % (input_shape, str((mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_, | |||
| diag_opt)))) | |||
| return mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_, diag_opt | |||
| mo_tile, ko_tile, no_tile, core_m_num, core_n_num, diag_opt = get_cus_tile_info(input_x1, input_x2, input_x3) | |||
| fp32_size = 4 | |||
| fp16_size = 2 | |||
| blocksize = 32 | |||
| vectorfp32_size = 64 | |||
| loop_n_num_total = no // no_tile | |||
| loop_m_num_total = mo // mo_tile | |||
| if loop_n_num_total % core_n_num != 0 or loop_m_num_total % core_m_num != 0: | |||
| raise ValueError("Does not support this scenario!") | |||
| loop_n_num = loop_n_num_total // core_n_num | |||
| loop_m_num = loop_m_num_total // core_m_num | |||
| block_num = core_n_num * core_m_num | |||
| loop_k_num = ko // ko_tile | |||
| if diag_opt: | |||
| loop_k_num = diag_outer // ko_tile | |||
| # double buffer: | |||
| thread_num_k = 2 | |||
| if ko_tile % 2 == 0: | |||
| loop_k_num *= thread_num_k | |||
| ko_tile_inner = ko_tile // thread_num_k | |||
| else: | |||
| ko_tile_inner = ko_tile | |||
| ko_tile *= thread_num_k | |||
| with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: | |||
| core_m = block_idx // core_n_num | |||
| core_n = block_idx % core_n_num | |||
| with tik_instance.for_range(0, loop_m_num) as cc_m: | |||
| with tik_instance.for_range(0, loop_n_num) as cc_n: | |||
| res_L0C = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], | |||
| name="resMatmul_L0C", scope=tik.scope_cc) | |||
| with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k: | |||
| if diag_opt: | |||
| k_idx = (core_n * loop_n_num + cc_n) * no_tile + thread_idx_k * ko_tile_inner | |||
| else: | |||
| k_idx = thread_idx_k * ko_tile_inner | |||
| # input_x1 -> input_x1_L1 | |||
| input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0], | |||
| name="input_x1_L1", scope=tik.scope_cbuf) | |||
| tik_instance.data_move(input_x1_L1, | |||
| input_x1[k_idx, | |||
| (core_m * loop_m_num + cc_m) * mo_tile, 0, 0], | |||
| 0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize, | |||
| (mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0) | |||
| # input_x2 -> input_x2_L1 | |||
| input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], | |||
| name="input_x2_L1", scope=tik.scope_cbuf) | |||
| tik_instance.data_move(input_x2_L1, | |||
| input_x2[(core_n * loop_n_num + cc_n) * no_tile, | |||
| k_idx, 0, 0], | |||
| 0, no_tile, ko_tile_inner * c0 * c0 * fp16_size // blocksize, | |||
| (ko - ko_tile_inner) * c0 * c0 * fp16_size // blocksize, 0) | |||
| # input_x1_L1 -> input_x1_L0A | |||
| input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0], | |||
| name="input_x1_L0A", scope=tik.scope_ca) | |||
| with tik_instance.for_range(0, mo_tile) as cc1: | |||
| tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner, | |||
| mo_tile, 0, False) | |||
| # input_x2_L1 -> input_x2_L0B | |||
| input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0], | |||
| name="input_x2_L0B", scope=tik.scope_cb) | |||
| with tik_instance.for_range(0, ko_tile_inner) as cc2: | |||
| tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile, | |||
| ko_tile_inner, | |||
| 0, True) | |||
| with tik_instance.if_scope(thread_idx_k == 0): | |||
| tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, | |||
| ko_tile_inner * c0, no_tile * c0, 0) | |||
| with tik_instance.else_scope(): | |||
| tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, | |||
| ko_tile_inner * c0, no_tile * c0, 1) | |||
| res_ub = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], | |||
| name="resMatmul_ub", scope=tik.scope_ubuf) | |||
| tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0) | |||
| input_3_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="input_3_local_UB") | |||
| tik_instance.data_move(input_3_local_UB, input_x3, 0, 1, 1, 0, 0) | |||
| matrix_max_scalar = tik_instance.Scalar("float32") | |||
| matrix_max_scalar.set_as(input_3_local_UB[0]) | |||
| repeate_num = no_tile * mo_tile * c0 * c0 // vectorfp32_size | |||
| repeate_times_max = 255 | |||
| count = 0 | |||
| while repeate_num > repeate_times_max: | |||
| tik_instance.vmuls(vectorfp32_size, | |||
| res_ub[count * repeate_times_max * vectorfp32_size], | |||
| res_ub[count * repeate_times_max * vectorfp32_size], | |||
| matrix_max_scalar, repeate_times_max, 1, 1, 8, 8) | |||
| repeate_num -= repeate_times_max | |||
| count += 1 | |||
| tik_instance.vmuls(vectorfp32_size, | |||
| res_ub[count * repeate_times_max * vectorfp32_size], | |||
| res_ub[count * repeate_times_max * vectorfp32_size], | |||
| matrix_max_scalar, repeate_num, 1, 1, 8, 8) | |||
| tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, | |||
| (core_m * loop_m_num + cc_m) * mo_tile, 0, 0], | |||
| res_ub, 0, no_tile, | |||
| mo_tile * c0 * c0 * fp32_size // blocksize, 0, | |||
| (mo - mo_tile) * c0 * c0 * fp32_size // blocksize) | |||
| @@ -0,0 +1,397 @@ | |||
| #!/usr/bin/env python | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from impl.matmul_vector import matmul_vector_cce | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| import te.lang.cce | |||
| import te.platform.cce_params as cce | |||
| from te import tvm | |||
| from topi import generic | |||
| from topi.cce import util | |||
| # General limitation of the size for input shape: 2**31 | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| NoneType = type(None) | |||
| matmul_cube_op_info = TBERegOp("CusMatMulCube") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("matmulcube.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusMatMulCube") \ | |||
| .partial_flag(True) \ | |||
| .attr("transpose_a", "required", "bool", "all") \ | |||
| .attr("transpose_b", "required", "bool", "all") \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| # pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, | |||
| def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): | |||
| """ | |||
| Check the given input if legal | |||
| Parameters: | |||
| shape_a: list or tuple | |||
| Shape of the first tensor a with rank > 1 | |||
| shape_b: list or tuple | |||
| Shape of the second tensor b with the same type with a, | |||
| and shape_a, shape_b must be 2 dims | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| src_dtype: str | |||
| The data type of input, support "float32", "float16" | |||
| trans_a: bool | |||
| If True, shape_a == transposed before multiplication | |||
| trans_b: bool | |||
| If True, shape_b == transposed before multiplication | |||
| Returns None | |||
| """ | |||
| shape_len = len(shape_a) | |||
| src_dtype = src_dtype.lower() | |||
| k_block_size = cce.BLOCK_REDUCE | |||
| check_list = ("float16") | |||
| if src_dtype not in check_list: | |||
| raise RuntimeError("matmul_cce only support %s while src_dtype == %s" | |||
| % (",".join(check_list), src_dtype)) | |||
| if shape_len != len(shape_b): | |||
| raise RuntimeError("length of a and b are not equal") | |||
| if shape_len != 2: | |||
| raise RuntimeError( | |||
| "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") | |||
| is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False | |||
| is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False | |||
| if trans_a: | |||
| m_shape = shape_a[shape_len - 1] | |||
| km_shape = shape_a[shape_len - 2] | |||
| else: | |||
| m_shape = shape_a[shape_len - 2] | |||
| km_shape = shape_a[shape_len - 1] | |||
| if trans_b: | |||
| kn_shape = shape_b[shape_len - 1] | |||
| n_shape = shape_b[shape_len - 2] | |||
| else: | |||
| kn_shape = shape_b[shape_len - 2] | |||
| n_shape = shape_b[shape_len - 1] | |||
| if m_shape == 1: | |||
| if n_shape == 1: | |||
| raise RuntimeError("input shape M and N can't both be 1") | |||
| if km_shape != kn_shape: | |||
| raise RuntimeError("reduce axis not same") | |||
| if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: | |||
| raise RuntimeError( | |||
| "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) | |||
| if m_shape != 1: | |||
| if n_shape == 1: | |||
| if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: | |||
| raise RuntimeError("input shape K1 should be multiple of %d" | |||
| % (cce.BLOCK_IN * cce.BLOCK_IN)) | |||
| elif km_shape % k_block_size != 0: | |||
| raise RuntimeError( | |||
| "input shape K1 should be multiple of %d" % cce.BLOCK_IN) | |||
| else: | |||
| if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: | |||
| raise RuntimeError("input shape K1 should be multiple of %d" | |||
| % (cce.BLOCK_IN * cce.BLOCK_IN)) | |||
| if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: | |||
| raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) | |||
| if len(shape_bias): | |||
| if len(shape_bias) == 1: | |||
| if is_gevm or is_gemv: | |||
| if shape_bias[0] != m_shape * n_shape: | |||
| raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") | |||
| else: | |||
| if shape_bias[0] != n_shape: | |||
| raise RuntimeError("broadcast bias shape must be equal to shape n") | |||
| elif len(shape_bias) == shape_len: | |||
| if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: | |||
| raise RuntimeError("non broadcast bias shape must be same as output shape") | |||
| else: | |||
| raise RuntimeError("unsupport input shape now for batch bias case") | |||
| def _get_bias(shape_bias): | |||
| """_get_bias""" | |||
| bias_length = shape_bias[0] | |||
| if bias_length % 16 == 0: | |||
| return shape_bias | |||
| else: | |||
| bias_length = (bias_length // 16) * 16 + 16 | |||
| shape_bias = [] | |||
| shape_bias.append(bias_length) | |||
| return shape_bias | |||
| def _get_input_shape(shape_x): | |||
| """_get_input_shape""" | |||
| dim_a = shape_x[0] | |||
| dim_b = shape_x[1] | |||
| res = [] | |||
| if dim_a % 16 != 0: | |||
| dim_a = (dim_a // 16) * 16 + 16 | |||
| res.append(dim_a) | |||
| else: | |||
| res.append(dim_a) | |||
| if dim_b % 16 != 0: | |||
| dim_b = (dim_b // 16) * 16 + 16 | |||
| res.append(dim_b) | |||
| else: | |||
| res.append(dim_b) | |||
| return res | |||
| def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): | |||
| """check_supported""" | |||
| shape_a = input_x1.get("shape") | |||
| shape_b = input_x2.get("shape") | |||
| print("shape_a: ", shape_a) | |||
| print("shape_b: ", shape_b) | |||
| src_dtype = input_x1.get("dtype") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape_a) | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| try: | |||
| trans_a_f = bool(1 - trans_a) | |||
| if src_dtype == "float32" or src_dtype == "int32": | |||
| if len(shape_a) != 2 and len(shape_b) != 2: | |||
| return False | |||
| if trans_b: | |||
| if shape_b[0] == 1: | |||
| return False | |||
| else: | |||
| if shape_b[1] == 1: | |||
| return False | |||
| if trans_a: | |||
| if trans_b: | |||
| if shape_a[0] != shape_b[1]: | |||
| return False | |||
| elif shape_a[0] != shape_b[0]: | |||
| return False | |||
| elif trans_b: | |||
| if shape_a[1] != shape_b[1]: | |||
| return False | |||
| elif shape_a[1] != shape_b[0]: | |||
| return False | |||
| if trans_a_f and trans_b and shape_b[1] == 1: | |||
| return False | |||
| if src_dtype == "float16": | |||
| if len(shape_a) != 2 and len(shape_b) != 2: | |||
| return False | |||
| if trans_a: | |||
| m_shape = shape_a[1] | |||
| k_shape = shape_a[0] | |||
| else: | |||
| m_shape = shape_a[0] | |||
| k_shape = shape_a[1] | |||
| if trans_b: | |||
| n_shape = shape_b[0] | |||
| k_b_shape = shape_b[1] | |||
| else: | |||
| n_shape = shape_b[1] | |||
| k_b_shape = shape_b[0] | |||
| if k_shape != k_b_shape: | |||
| return False | |||
| if m_shape == 1 or n_shape == 1: | |||
| if k_shape % 256 != 0: | |||
| return False | |||
| except RuntimeError as e: | |||
| return False | |||
| return True | |||
| # pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements | |||
| @op_info_register(matmul_cube_op_info) | |||
| def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): | |||
| """ | |||
| calculating matrix multiplication with bias, C = A*B + bias, support input | |||
| data with fractal format. | |||
| Parameters: | |||
| shape_a: list or tuple | |||
| Shape of the first tensor a with rank > 1 | |||
| shape_b: list or tuple | |||
| Shape of the second tensor b with the same type with a, | |||
| and shape_a, shape_b must be 2 dims | |||
| src_dtype: str | |||
| The data type of input, support "float32", "float16" | |||
| dst_dtype: str | |||
| The data type of output, support "float32", "float16" | |||
| trans_a: bool | |||
| If True, shape_a == transposed before multiplication | |||
| trans_b: bool | |||
| If True, shape_b == transposed before multiplication | |||
| is_fractal: bool | |||
| If True, the input data format of a and b must be fractal format | |||
| shape_bias: list or tuple | |||
| Shape of bias, only support the input data format with ND | |||
| Returns | |||
| ------- | |||
| None | |||
| """ | |||
| shape_a = input_x1.get("ori_shape") | |||
| shape_b = input_x2.get("ori_shape") | |||
| if shape_a is not None: | |||
| if len(shape_a) < 2: | |||
| shape_a = input_x1.get("shape") | |||
| if shape_b is not None: | |||
| if len(shape_b) < 2: | |||
| shape_b = input_x2.get("shape") | |||
| shape_a = list(shape_a) | |||
| shape_b = list(shape_b) | |||
| if input_x1.get("format") == "FRACTAL_NZ": | |||
| shape_a = _get_input_shape(shape_a) | |||
| shape_b = _get_input_shape(shape_b) | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape_a) | |||
| util.check_shape_rule(shape_b) | |||
| util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) | |||
| util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) | |||
| if input_x1.get("format") == "FRACTAL_NZ": | |||
| shape_a = [shape_a[1], shape_a[0]] | |||
| trans_a = bool(1 - trans_a) | |||
| if input_x2.get("format") == "FRACTAL_NZ": | |||
| shape_b = [shape_b[1], shape_b[0]] | |||
| trans_b = bool(1 - trans_b) | |||
| shape_bias = () | |||
| if bias is not None and bool(bias): | |||
| shape_bias = bias.get("shape") | |||
| shape_bias = list(shape_bias) | |||
| shape_bias = _get_bias(shape_bias) | |||
| src_dtype = input_x1.get("dtype").lower() | |||
| dst_dtype = output_y.get("dtype").lower() | |||
| if src_dtype == "float32" or src_dtype == "int32": | |||
| matmul_vector_cce(shape_a, shape_b, src_dtype, trans_a, trans_b, shape_bias, kernel_name) | |||
| return | |||
| _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) | |||
| m_shape = shape_a[len(shape_a) - 2] | |||
| km_shape = shape_a[len(shape_a) - 1] | |||
| kn_shape = shape_b[len(shape_a) - 2] | |||
| n_shape = shape_b[len(shape_a) - 1] | |||
| if src_dtype == "float16": | |||
| block_reduce = cce.BLOCK_REDUCE | |||
| block_in = cce.BLOCK_IN | |||
| block_out = cce.BLOCK_OUT | |||
| if trans_a and km_shape == 1: | |||
| block_in = cce.BLOCK_VECTOR | |||
| if not trans_a and m_shape == 1: | |||
| block_in = cce.BLOCK_VECTOR | |||
| if trans_b and kn_shape == 1: | |||
| block_out = cce.BLOCK_VECTOR | |||
| if not trans_b and n_shape == 1: | |||
| block_out = cce.BLOCK_VECTOR | |||
| if trans_a: | |||
| shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) | |||
| else: | |||
| shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) | |||
| if trans_b: | |||
| shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) | |||
| else: | |||
| shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) | |||
| if input_x1.get("format") == "FORMAT_FRACTAL_Z": | |||
| shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) | |||
| format_a = "fractal" | |||
| elif input_x1.get("format") == "FRACTAL_NZ": | |||
| shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) | |||
| format_a = "FRACTAL_NZ" | |||
| else: | |||
| shape_a_temp = (shape_a[len(shape_a) - 2], shape_a[len(shape_a) - 1]) | |||
| format_a = "ND" | |||
| if input_x2.get("format") == "FORMAT_FRACTAL_Z": | |||
| shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) | |||
| format_b = "fractal" | |||
| elif input_x2.get("format") == "FRACTAL_NZ": | |||
| shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) | |||
| format_b = "FRACTAL_NZ" | |||
| else: | |||
| shape_b_temp = (shape_b[len(shape_b) - 2], shape_b[len(shape_b) - 1]) | |||
| format_b = "ND" | |||
| tensor_bias = None | |||
| tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', | |||
| dtype=src_dtype) | |||
| tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', | |||
| dtype=src_dtype) | |||
| if len(shape_bias) > 0: | |||
| tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias', | |||
| dtype=dst_dtype) | |||
| result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a, | |||
| format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias) | |||
| with tvm.target.cce(): | |||
| schedule = generic.auto_schedule(result) | |||
| tensor_list = [tensor_a, tensor_b, result] | |||
| if len(shape_bias) > 0: | |||
| tensor_list = [tensor_a, tensor_b, tensor_bias, result] | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(schedule, config) | |||
| @@ -0,0 +1,81 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CusMatrixCombine""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| from te import tik | |||
| from topi.cce import util | |||
| cus_matrix_combine_op_info = TBERegOp("CusMatrixCombine") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("matrixcombine.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusMatrixCombine") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(cus_matrix_combine_op_info) | |||
| def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"): | |||
| """CusMatrixCombine""" | |||
| input_x_shape = input_x.get("shape") | |||
| output_shape = output.get("shape") | |||
| split_dim = 128 | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) | |||
| res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) | |||
| blocks = 32 | |||
| matrix_dim = input_x_shape[0] * input_x_shape[1] | |||
| if input_x_shape[0] == 1 and input_x_shape[1] == 64: | |||
| tiling_dim = 2 | |||
| bs = 1 | |||
| with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: | |||
| input_x_ub = tik_instance.Tensor("float32", (tiling_dim, matrix_dim), name="input_x_ub", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_x_ub, input_x[0, block_index * tiling_dim, 0], 0, 1, 16, 0, 0) | |||
| tik_instance.data_move(res[block_index * tiling_dim, 0], input_x_ub, 0, 1, 16, 0, 0) | |||
| else: | |||
| tiling_dim = 4 | |||
| bs = input_x_shape[0] | |||
| with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: | |||
| input_x_ub = tik_instance.Tensor("float32", (tiling_dim, matrix_dim), name="input_x_ub", | |||
| scope=tik.scope_ubuf) | |||
| zero = tik_instance.Scalar("float32") | |||
| zero.set_as(0.0) | |||
| with tik_instance.for_range(0, bs) as i: | |||
| repeat_real = tiling_dim * matrix_dim // 64 | |||
| if repeat_real <= 255: | |||
| tik_instance.vector_dup(64, input_x_ub, zero, repeat_real, 1, 8) | |||
| else: | |||
| repeat_1 = 255 | |||
| repeat_2 = repeat_real - 255 | |||
| tik_instance.vector_dup(64, input_x_ub, zero, repeat_1, 1, 8) | |||
| tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, repeat_2, 1, 8) | |||
| with tik_instance.for_range(0, tiling_dim) as j: | |||
| tik_instance.data_move(input_x_ub[j, split_dim * i], input_x[i, block_index * tiling_dim + j, 0], 0, | |||
| 1, 16, 0, 0) | |||
| tik_instance.data_move(res[i * split_dim + block_index * tiling_dim, 0], input_x_ub, 0, 1, | |||
| tiling_dim * matrix_dim * 4 // 32, 0, 0) | |||
| tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) | |||
| return tik_instance | |||
| @@ -0,0 +1,289 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CusTranspose02314""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| from te import tik | |||
| from topi.cce import util | |||
| cus_transpose02314_op_info = TBERegOp("CusTranspose02314") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("transpose02314.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("CusTranspose02314") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(cus_transpose02314_op_info) | |||
| def CusTranspose02314(input_x, output, kernel_name="transpose021354"): | |||
| """CusTranspose02314""" | |||
| input_x_shape = input_x.get("shape") | |||
| output_shape = output.get("shape") | |||
| perm = (0, 2, 3, 1, 4) | |||
| input_x_shape = tuple(input_x_shape) | |||
| support_shape = [(32, 128, 7, 7, 16), | |||
| (32, 32, 7, 7, 16), | |||
| (32, 32, 14, 14, 16), | |||
| (32, 64, 14, 14, 16), | |||
| (32, 16, 14, 14, 16), | |||
| (32, 16, 28, 28, 16), | |||
| (32, 32, 28, 28, 16), | |||
| (32, 8, 28, 28, 16), | |||
| (32, 8, 56, 56, 16), | |||
| (32, 16, 56, 56, 16), | |||
| (32, 4, 56, 56, 16), | |||
| (32, 4, 112, 112, 16)] | |||
| if input_x_shape not in support_shape: | |||
| raise RuntimeError("input_shape %s is not supported" % str(input_x_shape)) | |||
| if util.get_product_version() == util.VERSION_MINI: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) | |||
| else: | |||
| tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) | |||
| input_x = tik_instance.Tensor("float16", input_x_shape, name="input_x", scope=tik.scope_gm) | |||
| res = tik_instance.Tensor("float16", output_shape, name="res", scope=tik.scope_gm) | |||
| dtype = "float16" | |||
| if tuple(input_x_shape) == (32, 4, 112, 112, 16): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| with tik_instance.for_range(0, 14) as cc1_db: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as db_idx: | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| zero = tik_instance.Scalar(dtype="float16", init_value=0) | |||
| tik_instance.data_move(input_1_local_UB, | |||
| input_x[block_idx * 802816 + cc1_db * 14336 + 7168 * db_idx], 0, 4, 448, | |||
| 12096, 0) | |||
| with tik_instance.for_range(0, 448) as cc7: | |||
| with tik_instance.for_range(0, 4) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB[cc7 * 64 + cc8 * 16], | |||
| input_1_local_UB[7168 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 802816 + cc1_db * 57344 + 28672 * db_idx], | |||
| T_transpose_local_UB, 0, 1, 1792, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 4, 56, 56, 16): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| zero = tik_instance.Scalar(dtype="float16", init_value=0) | |||
| with tik_instance.for_range(0, 3) as cc1_db: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as db_idx: | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, | |||
| input_x[block_idx * 200704 + cc1_db * 14336 + 7168 * db_idx], 0, 4, 448, | |||
| 2688, 0) | |||
| with tik_instance.for_range(0, 448) as cc7: | |||
| with tik_instance.for_range(0, 4) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB[cc7 * 64 + cc8 * 16], | |||
| input_1_local_UB[7168 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 200704 + cc1_db * 57344 + 28672 * db_idx], | |||
| T_transpose_local_UB, 0, 1, 1792, 0, 0) | |||
| input_1_local_UB2 = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB2", scope=tik.scope_ubuf) | |||
| T_transpose_local_UB2 = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB2", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB2, input_x[block_idx * 200704 + 43008], 0, 4, 448, 2688, 0) | |||
| with tik_instance.for_range(0, 448) as cc72: | |||
| with tik_instance.for_range(0, 4) as cc82: | |||
| tik_instance.vadds(16, T_transpose_local_UB2[cc72 * 64 + cc82 * 16], | |||
| input_1_local_UB2[7168 * cc82 + cc72 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 200704 + 172032], T_transpose_local_UB2, 0, 1, 1792, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 16, 56, 56, 16): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| zero = tik_instance.Scalar(dtype="float16", init_value=0) | |||
| with tik_instance.for_range(0, 14) as cc1_db: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as db_idx: | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, | |||
| input_x[block_idx * 802816 + cc1_db * 3584 + 1792 * db_idx], 0, 16, 112, | |||
| 3024, 0) | |||
| with tik_instance.for_range(0, 112) as cc7: | |||
| with tik_instance.for_range(0, 16) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], | |||
| input_1_local_UB[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 802816 + cc1_db * 57344 + 28672 * db_idx], | |||
| T_transpose_local_UB, 0, 1, 1792, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 8, 56, 56, 16): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| zero = tik_instance.Scalar(dtype="float16", init_value=0) | |||
| with tik_instance.for_range(0, 7) as cc1_db: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as db_idx: | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, | |||
| input_x[block_idx * 401408 + cc1_db * 7168 + 3584 * db_idx], 0, 8, 224, 2912, | |||
| 0) | |||
| with tik_instance.for_range(0, 224) as cc7: | |||
| with tik_instance.for_range(0, 16) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB[cc7 * 128 + cc8 * 16], | |||
| input_1_local_UB[3584 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 401408 + cc1_db * 57344 + 28672 * db_idx], | |||
| T_transpose_local_UB, 0, 1, 1792, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 8, 28, 28, 16): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| zero = tik_instance.Scalar(dtype="float16", init_value=0) | |||
| with tik_instance.for_range(0, 2) as cc1_db: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as db_idx: | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [25088], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| T_transpose_local_UB = tik_instance.Tensor(dtype, [25088], name="T_transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, | |||
| input_x[block_idx * 100352 + cc1_db * 6272 + 3136 * db_idx], 0, 8, 196, 588, | |||
| 0) | |||
| with tik_instance.for_range(0, 196) as cc7: | |||
| with tik_instance.for_range(0, 8) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB[cc7 * 128 + cc8 * 16], | |||
| input_1_local_UB[3136 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 100352 + cc1_db * 50176 + 25088 * db_idx], | |||
| T_transpose_local_UB, 0, 1, 1568, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 32, 28, 28, 16): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| zero = tik_instance.Scalar(dtype="float16", init_value=0) | |||
| with tik_instance.for_range(0, 7) as cc1_db: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as db_idx: | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, input_x[block_idx * 401408 + cc1_db * 1792 + 896 * db_idx], | |||
| 0, 32, 56, 728, 0) | |||
| with tik_instance.for_range(0, 56) as cc7: | |||
| with tik_instance.for_range(0, 32) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB[cc7 * 512 + cc8 * 16], | |||
| input_1_local_UB[896 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 401408 + cc1_db * 57344 + 28672 * db_idx], | |||
| T_transpose_local_UB, 0, 1, 1792, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 16, 28, 28, 16): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| zero = tik_instance.Scalar(dtype="float16", init_value=0) | |||
| with tik_instance.for_range(0, 3) as cc1_db: | |||
| with tik_instance.for_range(0, 2, thread_num=2) as db_idx: | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, | |||
| input_x[block_idx * 200704 + cc1_db * 3584 + 1792 * db_idx], 0, 16, 112, 672, | |||
| 0) | |||
| with tik_instance.for_range(0, 112) as cc7: | |||
| with tik_instance.for_range(0, 16) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], | |||
| input_1_local_UB[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 200704 + cc1_db * 57344 + 28672 * db_idx], | |||
| T_transpose_local_UB, 0, 1, 1792, 0, 0) | |||
| input_1_local_UB2 = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB2", scope=tik.scope_ubuf) | |||
| T_transpose_local_UB2 = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB2", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB2, input_x[block_idx * 200704 + 10752], 0, 16, 112, 672, 0) | |||
| with tik_instance.for_range(0, 112) as cc7: | |||
| with tik_instance.for_range(0, 16) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB2[cc7 * 256 + cc8 * 16], | |||
| input_1_local_UB2[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 200704 + 172032], T_transpose_local_UB2, 0, 1, 1792, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 16, 14, 14, 16): | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| zero = tik_instance.Scalar(dtype="float16", init_value=0) | |||
| with tik_instance.for_range(0, 2, thread_num=2) as db_idx: | |||
| input_1_local_UB = tik_instance.Tensor(dtype, [25088], name="input_1_local_UB", scope=tik.scope_ubuf) | |||
| T_transpose_local_UB = tik_instance.Tensor(dtype, [25088], name="T_transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_1_local_UB, input_x[block_idx * 50176 + 1568 * db_idx], 0, 16, 98, 98, 0) | |||
| with tik_instance.for_range(0, 98) as cc7: | |||
| with tik_instance.for_range(0, 16) as cc8: | |||
| tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], | |||
| input_1_local_UB[1568 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 50176 + 25088 * db_idx], T_transpose_local_UB, 0, 1, 1568, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 128, 7, 7, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| with tik_instance.for_range(0, 7, thread_num=2) as cc1: | |||
| input_x_ub = tik_instance.Tensor(dtype, [1, 128, 1, 7, 16], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| transpose_ub = tik_instance.Tensor(dtype, [1, 1, 7, 128, 16], name="transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_x_ub, input_x[block_idx, 0, cc1, 0, 0], 0, 128, 7, 42, 0) | |||
| with tik_instance.for_range(0, 7) as cc7: | |||
| with tik_instance.for_range(0, 128) as cc8: | |||
| tik_instance.vadds(16, transpose_ub[0, 0, cc7, cc8, 0], input_x_ub[0, cc8, 0, cc7, 0], 0, | |||
| 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 100352 + 14336 * cc1], transpose_ub, 0, 1, 896, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 32, 7, 7, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| input_x_ub = tik_instance.Tensor(dtype, [1, 32, 7, 7, 16], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| transpose_ub = tik_instance.Tensor(dtype, [1, 7, 7, 32, 16], name="transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_x_ub, input_x[block_idx, 0, 0, 0, 0], 0, 1, 1568, 0, 0) | |||
| with tik_instance.for_range(0, 7) as cc1: | |||
| with tik_instance.for_range(0, 7) as cc2: | |||
| with tik_instance.for_range(0, 32) as cc3: | |||
| tik_instance.vadds(16, transpose_ub[0, cc1, cc2, cc3, 0], input_x_ub[0, cc3, cc1, cc2, 0], 0, | |||
| 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 25088], transpose_ub, 0, 1, 1568, 0, 0) | |||
| elif tuple(input_x_shape) == (32, 32, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": | |||
| def _inner_compute(split_index): | |||
| input_x_ub = tik_instance.Tensor(dtype, [1, 32, 2, 14, 16], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| transpose_ub = tik_instance.Tensor(dtype, [1, 2, 14, 32, 16], name="transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_x_ub, input_x[block_idx, 0, split_index * 2, 0, 0], 0, 32, 28, 168, 0) | |||
| with tik_instance.for_range(0, 2) as cc2: | |||
| with tik_instance.for_range(0, 14) as cc3: | |||
| with tik_instance.for_range(0, 32) as cc4: | |||
| tik_instance.vadds(16, transpose_ub[0, cc2, cc3, cc4, 0], input_x_ub[0, cc4, cc2, cc3, 0], | |||
| 0, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 100352 + split_index * 2 * 7168], transpose_ub, 0, 1, 896, 0, 0) | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| with tik_instance.for_range(0, 6, thread_num=2) as cc1: | |||
| _inner_compute(cc1) | |||
| _inner_compute(6) | |||
| elif tuple(input_x_shape) == (32, 64, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": | |||
| def _inner_compute(split_index, block_idx): | |||
| input_x_ub = tik_instance.Tensor(dtype, [1, 64, 2, 14, 16], name="input_1_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| transpose_ub = tik_instance.Tensor(dtype, [1, 2, 14, 64, 16], name="transpose_local_UB", | |||
| scope=tik.scope_ubuf) | |||
| tik_instance.data_move(input_x_ub, input_x[block_idx, 0, split_index * 2, 0, 0], 0, 64, 28, 168, 0) | |||
| with tik_instance.for_range(0, 2) as cc2: | |||
| with tik_instance.for_range(0, 14) as cc3: | |||
| with tik_instance.for_range(0, 64) as cc4: | |||
| tik_instance.vadds(16, transpose_ub[0, cc2, cc3, cc4, 0], input_x_ub[0, cc4, cc2, cc3, 0], | |||
| 0, 1, 1, 1, 0, 0) | |||
| tik_instance.data_move(res[block_idx * 200704 + split_index * 2 * 14336], transpose_ub, 0, 1, 1792, 0, 0) | |||
| with tik_instance.for_range(0, 32, block_num=32) as block_idx: | |||
| with tik_instance.for_range(0, 6, thread_num=2) as cc1: | |||
| _inner_compute(cc1, block_idx) | |||
| _inner_compute(6, block_idx) | |||
| tik_instance.BuildCCE(kernel_name, inputs=[input_x], outputs=[res]) | |||
| return tik_instance | |||
| @@ -1,76 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """batch_matmul_impl""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "CusBatchMatMul", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "batchmatmul.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusBatchMatMul", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): | |||
| """CusBatchMatMul""" | |||
| return | |||
| @@ -1,64 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CusCholeskyTrsm""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "CusCholeskyTrsm", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "choleskytrsm.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusCholeskyTrsm", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| def CusCholeskyTrsm(input_x, output, kernel_name): | |||
| """CusCholeskyTrsm""" | |||
| return | |||
| @@ -1,69 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CusFusedAbsMax1""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "CusFusedAbsMax1", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "fusedabsmax1.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusFusedAbsMax1", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "origin_shape", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"): | |||
| """CusFusedAbsMax1""" | |||
| return | |||
| @@ -1,87 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CusImg2ColNC1HWC0""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "CusImg2ColNC1HWC0", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "img2colnc1hwc0.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusImg2ColNC1HWC0", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "ksizes", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "strides", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "dilates", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "padding", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| def CusImg2ColNC1HWC0(input_x, output, ksizes, strides, dilates, padding, kernel_name="img2col"): | |||
| """CusImg2ColNC1HWC0""" | |||
| return | |||
| @@ -1,101 +0,0 @@ | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from topi.cce import util | |||
| # General limitation of the size for input shape: 2**31 | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| NoneType = type(None) | |||
| @op_info_register("""{ | |||
| "op_name": "CusMatMulCubeDenseLeft", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "matmulcubedenseleft.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusMatMulCubeDenseLeft", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x3", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) | |||
| def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="matmulcube"): | |||
| """CusMatMulCubeDenseLeft""" | |||
| return | |||
| @@ -1,102 +0,0 @@ | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from topi.cce import util | |||
| # General limitation of the size for input shape: 2**31 | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| NoneType = type(None) | |||
| @op_info_register("""{ | |||
| "op_name": "CusMatMulCubeFraczLeftCast", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "matmulcubefraczleftcast.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusMatMulCubeFraczLeftCast", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "FracZ" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x3", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FracZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| # pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements | |||
| @util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) | |||
| def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="CusMatMulCubeFraczLeftCast"): | |||
| """CusMatMulCubeFraczLeftCast""" | |||
| return | |||
| @@ -1,113 +0,0 @@ | |||
| #!/usr/bin/env python | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| # General limitation of the size for input shape: 2**31 | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| NoneType = type(None) | |||
| @op_info_register("""{ | |||
| "op_name": "CusMatMulCubeFraczRightMul", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "matmulcubefraczrightmul.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusMatMulCubeFraczRightMul", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FracZ" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x3", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x4", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "FracZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, | |||
| kernel_name="matmulcube"): | |||
| """CusMatMulCubeFraczRightMul""" | |||
| return | |||
| @@ -1,114 +0,0 @@ | |||
| #!/usr/bin/env python | |||
| # -*- coding:utf-8 -*- | |||
| """ | |||
| copyright 2020 Huawei Technologies Co., Ltd | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License == distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| matmul | |||
| """ | |||
| from __future__ import absolute_import | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from topi.cce import util | |||
| # General limitation of the size for input shape: 2**31 | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| NoneType = type(None) | |||
| @op_info_register("""{ | |||
| "op_name": "CusMatMulCube", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "matmulcube.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusMatMulCube", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "transpose_a", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "transpose_b", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x3", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| # pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements | |||
| @util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) | |||
| def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): | |||
| """CusMatMulCube""" | |||
| return | |||
| @@ -1,63 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CusMatrixCombine""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "CusMatrixCombine", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "matrixcombine.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusMatrixCombine", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"): | |||
| """CusMatrixCombine""" | |||
| return | |||
| @@ -1,63 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CusTranspose02314""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "CusTranspose02314", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "transpose02314.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "CusTranspose02314", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| def CusTranspose02314(input_x, output, kernel_name="transpose021354"): | |||
| """CusTranspose02314""" | |||
| return | |||
| @@ -70,6 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||
| from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop | |||
| from . import _quant_ops | |||
| from ._quant_ops import * | |||
| from .thor_ops import * | |||
| __all__ = [ | |||
| 'TensorAdd', | |||
| @@ -262,5 +263,6 @@ __all__ = [ | |||
| "SquareSumAll" | |||
| ] | |||
| __all__.extend(thor_ops.__all__) | |||
| __all__.extend(_quant_ops.__all__) | |||
| __all__.sort() | |||
| @@ -17,13 +17,26 @@ import mindspore as ms | |||
| from mindspore.ops import prim_attr_register, PrimitiveWithInfer | |||
| from mindspore.ops.composite import multitype_ops as C | |||
| __all__ = ["CusBatchMatMul", | |||
| "CusCholeskyTrsm", | |||
| "CusFusedAbsMax1", | |||
| "CusImg2Col", | |||
| "CusMatMulCubeDenseLeft", | |||
| "CusMatMulCubeFraczRightMul", | |||
| "CusMatMulCube", | |||
| "CusMatrixCombine", | |||
| "CusTranspose02314", | |||
| "CusMatMulCubeDenseRight", | |||
| "CusMatMulCubeFraczLeftCast", | |||
| ] | |||
| class CusBatchMatMul(PrimitiveWithInfer): | |||
| """CusMatMulCube definition""" | |||
| """CusBatchMatMul definition""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init CusMatMulCube""" | |||
| """init CusBatchMatMul""" | |||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) | |||
| def get_bprop(self): | |||
| @@ -61,11 +74,11 @@ class CusCholeskyTrsm(PrimitiveWithInfer): | |||
| class CusFusedAbsMax1(PrimitiveWithInfer): | |||
| """CusCholeskyTrsm definition""" | |||
| """CusFusedAbsMax1 definition""" | |||
| @prim_attr_register | |||
| def __init__(self, origin_shape=[-1, -1]): | |||
| """init CusCholeskyTrsm""" | |||
| """init CusFusedAbsMax1""" | |||
| self.init_prim_io_names(inputs=['x1'], outputs=['y']) | |||
| self.origin_shape = origin_shape | |||
| @@ -126,7 +139,7 @@ class CusMatMulCubeDenseLeft(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init CusMatMulCube""" | |||
| """init CusMatMulCubeDenseLeft""" | |||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) | |||
| def get_bprop(self): | |||
| @@ -199,11 +212,11 @@ class CusMatMulCube(PrimitiveWithInfer): | |||
| class CusMatrixCombine(PrimitiveWithInfer): | |||
| """CusMatMulCube definition""" | |||
| """CusMatrixCombine definition""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init CusMatMulCube""" | |||
| """init CusMatrixCombine""" | |||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||
| def get_bprop(self): | |||
| @@ -246,3 +259,45 @@ class CusTranspose02314(PrimitiveWithInfer): | |||
| def infer_dtype(self, data1_dtype): | |||
| return data1_dtype | |||
| class CusMatMulCubeDenseRight(PrimitiveWithInfer): | |||
| """CusMatMulCubeDenseRight definition""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init CusMatMulCubeDenseRight""" | |||
| self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) | |||
| def get_bprop(self): | |||
| def bprop(x1, x2, x3, out, dout): | |||
| return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3)) | |||
| return bprop | |||
| def infer_shape(self, data1_shape, data2_shape, data3_shape): | |||
| return data1_shape | |||
| def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype): | |||
| return ms.common.dtype.tensor_type(getattr(ms, "float32")) | |||
| class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer): | |||
| """CusMatMulCubeFraczLeftCast definition""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init CusMatMulCubeFraczLeftCast""" | |||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) | |||
| def get_bprop(self): | |||
| def bprop(x1, x2, out, dout): | |||
| return (C.zeros_like(x1), C.zeros_like(x2)) | |||
| return bprop | |||
| def infer_shape(self, data1_shape, data2_shape): | |||
| return data2_shape | |||
| def infer_dtype(self, data1_dtype, data2_dtype): | |||
| return ms.common.dtype.tensor_type(getattr(ms, "float16")) | |||