From f34d06cd0f90b3a6db2b1d960d309e87a9281146 Mon Sep 17 00:00:00 2001
From: z00478463
Date: Tue, 26 May 2020 10:39:42 +0800
Subject: [PATCH] for pylint 4th

---
 example/resnet50_imagenet2012_THOR/config.py  |  2 +-
 .../resnet50_imagenet2012_THOR/model/thor.py  | 14 ++-----
 .../model/thor_layer.py                       | 38 +++++++------------
 .../_op_impl/_custom_op/batch_matmul_impl.py  |  5 ++-
 .../_op_impl/_custom_op/cholesky_trsm_impl.py |  1 +
 .../_custom_op/fused_abs_max1_impl.py         |  9 +++--
 .../ops/_op_impl/_custom_op/img2col_impl.py   |  3 +-
 .../_custom_op/matmul_cube_dense_left_impl.py | 19 +++++-----
 .../matmul_cube_dense_right_impl.py           | 13 ++++---
 .../matmul_cube_fracz_left_cast_impl.py       | 16 ++++----
 .../matmul_cube_fracz_right_mul_impl.py       |  9 +++--
 .../_op_impl/_custom_op/matmul_cube_impl.py   | 14 +++----
 .../_custom_op/transpose02314_impl.py         |  3 +-
 mindspore/ops/operations/thor_ops.py          |  2 +-
 14 files changed, 73 insertions(+), 75 deletions(-)

diff --git a/example/resnet50_imagenet2012_THOR/config.py b/example/resnet50_imagenet2012_THOR/config.py
index b9df4947aa..fc01287cc8 100644
--- a/example/resnet50_imagenet2012_THOR/config.py
+++ b/example/resnet50_imagenet2012_THOR/config.py
@@ -33,5 +33,5 @@ config = ed({
     "save_checkpoint_path": "./",
     "label_smooth": 1,
     "label_smooth_factor": 0.1,
-    "frequency": 278
+    "frequency": 834
 })
diff --git a/example/resnet50_imagenet2012_THOR/model/thor.py b/example/resnet50_imagenet2012_THOR/model/thor.py
index 44c0fd45db..d414f23851 100644
--- a/example/resnet50_imagenet2012_THOR/model/thor.py
+++ b/example/resnet50_imagenet2012_THOR/model/thor.py
@@ -22,12 +22,6 @@ from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
 
-from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight
-from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast
-from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft
-from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul
-from model.grad_reducer_thor import DistributedGradReducerThor
-
 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
 
 
@@ -68,10 +62,10 @@ class THOR(Optimizer):
         self.matrix_G = ParameterTuple(matrix_G)
         self.A_inv_max = ParameterTuple(A_inv_max)
         self.G_inv_max = ParameterTuple(G_inv_max)
-        self.cube_matmul_left = CusMatMulCubeFraczLeftCast()
-        self.cube_matmul_left_fc = CusMatMulCubeDenseLeft()
-        self.cube_matmul_right_fc = CusMatMulCubeDenseRight()
-        self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul()
+        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
+        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
+        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
+        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
         self.transpose = P.Transpose()
         self.shape = P.Shape()
         self.reshape = P.Reshape()
diff --git a/example/resnet50_imagenet2012_THOR/model/thor_layer.py b/example/resnet50_imagenet2012_THOR/model/thor_layer.py
index 8097d729ea..fea74605b6 100644
--- a/example/resnet50_imagenet2012_THOR/model/thor_layer.py
+++ b/example/resnet50_imagenet2012_THOR/model/thor_layer.py
@@ -23,19 +23,9 @@ from mindspore.common.tensor import Tensor
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 from mindspore.ops import operations as P
-
-from cus_ops.cus_batch_matmul import CusBatchMatMul
-from cus_ops.cus_cholesky_trsm import CusCholeskyTrsm
-from cus_ops.cus_fused_abs_max1 import CusFusedAbsMax1
-from cus_ops.cus_img2col import CusImg2Col
-from cus_ops.cus_matmul_cube import CusMatMulCube
-from cus_ops.cus_matrix_combine import CusMatrixCombine
-from cus_ops.cus_transpose02314 import CusTranspose02314
-
 import numpy as np
 
 C0 = 16
-
 def caculate_device_shape(matrix_dim, channel, is_A):
     ll = (0)
     if is_A:
@@ -153,11 +143,11 @@ class Conv2d_Thor(_Conv):
             group=self.group
         )
 
-        self.img2col = CusImg2Col(ksizes=ksizes, strides=strides)
-        self.cube_matmul = CusMatMulCube(transpose_a=True)
-        self.matrix_combine = CusMatrixCombine()
-        self.cholesky = CusCholeskyTrsm()
-        self.transpose02314 = CusTranspose02314()
+        self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
+        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
+        self.matrix_combine = P.CusMatrixCombine()
+        self.cholesky = P.CusCholeskyTrsm()
+        self.transpose02314 = P.CusTranspose02314()
         self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
         self.matrix_G_dim = self.out_channels
         self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim,
@@ -190,7 +180,7 @@ class Conv2d_Thor(_Conv):
         self.mul = P.Mul()
         self.cast = P.Cast()
         self.damping = Tensor(damping)
-        self.vector_matmul = CusBatchMatMul()
+        self.vector_matmul = P.CusBatchMatMul()
         self.diag_block_dim = 128
         self.channels_slice_flag = False
         if self.in_channels % C0 != 0:
@@ -221,8 +211,8 @@ class Conv2d_Thor(_Conv):
 
         self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32)
         self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32)
-        self.fused_abs_max1 = CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim])
-        self.fused_abs_max2 = CusFusedAbsMax1()
+        self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim])
+        self.fused_abs_max2 = P.CusFusedAbsMax1()
         self.log = P.Log()
         self.exp = P.Exp()
         self.sqrt = P.Sqrt()
@@ -375,9 +365,9 @@ class Dense_Thor(Cell):
 
         self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16))
         self.matmul = P.MatMul(transpose_b=True)
-        self.cube_matmul = CusMatMulCube(transpose_a=True)
-        self.matrix_combine = CusMatrixCombine()
-        self.cholesky = CusCholeskyTrsm()
+        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
+        self.matrix_combine = P.CusMatrixCombine()
+        self.cholesky = P.CusCholeskyTrsm()
         self.shape = P.Shape()
         self.reshape = P.Reshape()
         self.transpose = P.Transpose()
@@ -386,7 +376,7 @@ class Dense_Thor(Cell):
         self.cast = P.Cast()
         self.damping = Tensor(damping)
         self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
-        self.vector_matmul = CusBatchMatMul()
+        self.vector_matmul = P.CusBatchMatMul()
         self.pad = P.Pad(((0, 24), (0, 24)))
         self.pad1 = P.Pad(((0, 8), (0, 8)))
         self.slice = P.Slice()
@@ -396,8 +386,8 @@ class Dense_Thor(Cell):
         self.axis = 0
         self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
         self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
-        self.fused_abs_max1 = CusFusedAbsMax1([1000, 1000])
-        self.fused_abs_max2 = CusFusedAbsMax1()
+        self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000])
+        self.fused_abs_max2 = P.CusFusedAbsMax1()
         self.log = P.Log()
         self.exp = P.Exp()
         self.dampingA = Tensor(np.identity(2048), mstype.float32)
diff --git a/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py
index d8395c1e81..97982c53cf 100644
--- a/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py
@@ -33,6 +33,7 @@ cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \
 
 
 def _get_flattern_shape(shape):
+    """_get_flattern_shape"""
     flattern_shape = 1
     for dim in shape:
         flattern_shape *= dim
@@ -40,6 +41,7 @@ def _get_flattern_shape(shape):
 
 
 def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index):
+    """_inner_matmul_new"""
     input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", scope=tik.scope_ubuf)
     t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", scope=tik.scope_ubuf)
     tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 16, 0, 0)
@@ -71,6 +73,7 @@ def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_
 
 
 def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index):
+    """_inner_matmul_new_1_64_32_64"""
     input_1_local_UB = tik_instance.Tensor(dtype, [64], name="input_1_local_UB", scope=tik.scope_ubuf)
     tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 8, 0, 0)
     with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2:
@@ -90,6 +93,7 @@ def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, inpu
 
 @op_info_register(cus_batchmatmul_op_info)
 def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
+    """CusBatchMatMul"""
     if util.get_product_version() == util.VERSION_MINI:
         tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
     else:
@@ -116,7 +120,6 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
 
     # if not transpose_a and transpose_b:
     batch, m, k = x1_shape
-    _, n, _ = x2_shape
     input1_shape = _get_flattern_shape(x1_shape)
     input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
 
diff --git a/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py b/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py
index 50830fe0f6..71dd1ccb2d 100644
--- a/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py
@@ -32,6 +32,7 @@ cus_cholesky_trsm_op_info = TBERegOp("CusCholeskyTrsm") \
 
 @op_info_register(cus_cholesky_trsm_op_info)
 def CusCholeskyTrsm(input_x, output, kernel_name):
+    """CusCholeskyTrsm"""
     input_x_shape = input_x.get("shape")
     output_shape = output.get("shape")
     split_dim = 128
diff --git a/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py b/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py
index 0c47ce78b1..f4b8d44063 100644
--- a/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py
@@ -33,6 +33,7 @@ cus_fused_abs_max1_op_info = TBERegOp("CusFusedAbsMax1") \
 
 @op_info_register(cus_fused_abs_max1_op_info)
 def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"):
+    """CusFusedAbsMax1"""
     input_x_shape = input_x.get("shape")
     output_shape = output.get("shape")
 
@@ -203,9 +204,9 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
             tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
     elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or (
             input_x_shape[0] == 32 and input_x_shape[1] == 16) or (
-            input_x_shape[0] == 16 and input_x_shape[1] == 32):
+                input_x_shape[0] == 16 and input_x_shape[1] == 32):
         if (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[
-                0] == 1000:
+            0] == 1000:
             input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
             res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
             blocks = 32
@@ -257,7 +258,7 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
 
             tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
     elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[
-            0] == 1001:
+        0] == 1001:
         input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
         res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
         blocks = 32
@@ -350,7 +351,7 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
             tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
     elif (input_x_shape[0] == 16 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or (
             input_x_shape[0] == 16 and input_x_shape[1] == 64) or (
-            input_x_shape[0] == 64 and input_x_shape[1] == 16):
+                input_x_shape[0] == 64 and input_x_shape[1] == 16):
         input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
         res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
         total_elements = 1
diff --git a/mindspore/ops/_op_impl/_custom_op/img2col_impl.py b/mindspore/ops/_op_impl/_custom_op/img2col_impl.py
index 8c1fd1262f..433e335565 100644
--- a/mindspore/ops/_op_impl/_custom_op/img2col_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/img2col_impl.py
@@ -36,6 +36,7 @@ cus_img2col_info = TBERegOp("CusImg2Col") \
 
 @op_info_register(cus_img2col_info)
 def CusImg2Col(input_x, output, ksizes, strides, dilates, mode, kernel_name="img2col"):
+    """CusImg2Col"""
     input_x_shape = input_x.get("shape")
     input_x_dtype = input_x.get("dtype")
     N, C1, H, W, C0 = input_x_shape
@@ -64,7 +65,7 @@ def CusImg2Col(input_x, output, ksizes, strides, dilates, mode, kernel_name="img
                        ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1)),
                        ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1)),
                        ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1)),
-                       ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)), ]
+                       ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)),]
     if input_shape not in supported_shape:
         raise RuntimeError("input_shape %s is not supported" % str(input_shape))
 
diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py
index 0458363a6d..2d70263bc1 100644
--- a/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py
@@ -17,10 +17,9 @@ limitations under the License.
 matmul
 """
 from __future__ import absolute_import
-
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 import te.lang.cce
 import te.platform.cce_params as cce
-from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 from te import tik
 from te import tvm
 from topi import generic
@@ -128,7 +127,7 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
 
     if n_shape % cce.BLOCK_IN != 0 and n_shape != 1:
         raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN)
-    if len(shape_bias):
+    if len(shape_bias) != 0:
         if len(shape_bias) == 1:
             if is_gevm or is_gemv:
                 if shape_bias[0] != m_shape * n_shape:
@@ -145,16 +144,19 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
 
 def _get_bias(shape_bias):
     bias_length = shape_bias[0]
+    shb = []
     if bias_length % 16 == 0:
-        return shape_bias
+        shb = shape_bias
     else:
         bias_length = (bias_length // 16) * 16 + 16
         shape_bias = []
         shape_bias.append(bias_length)
-        return shape_bias
+        shb = shape_bias
+    return shb
 
 
 def _get_input_shape(shape_x):
+    """_get_input_shape"""
     dim_a = shape_x[0]
     dim_b = shape_x[1]
     res = []
@@ -173,6 +175,7 @@ def _get_input_shape(shape_x):
 
 def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
                     kernel_name="matmulcube"):
+    """check_supported"""
     shape_a = input_x1.get("shape")
     shape_b = input_x2.get("shape")
     print("shape_a: ", shape_a)
@@ -183,8 +186,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t
     util.check_shape_rule(shape_b)
     util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
     util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
-    if bias is not None and bool(bias):
-        shape_bias = bias.get("shape")
     try:
         trans_a_f = bool(1 - trans_a)
         if src_dtype == "float32" or src_dtype == "int32":
@@ -250,7 +251,7 @@ def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=F
     """
     calculating matrix multiplication with bias, C = A*B + bias, support input
     data with fractal format.
-    
+
     Parameters:
     shape_a: list or tuple
             Shape of the first tensor a with rank > 1
@@ -269,7 +270,7 @@ def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=F
             If True, the input data format of a and b must be fractal format
     shape_bias: list or tuple
             Shape of bias, only support the input data format with ND
-    
+
     Returns
     -------
     None
diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py
index 5cae9afda0..4a1982738d 100644
--- a/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py
@@ -2,19 +2,19 @@
 # -*- coding:utf-8 -*-
 """
 copyright 2020 Huawei Technologies Co., Ltd
- 
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
- 
+
 http://www.apache.org/licenses/LICENSE-2.0
- 
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License == distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
- 
+
 matmul
 """
 from __future__ import absolute_import
@@ -43,11 +43,12 @@ matmul_cube_dense_right_op_info = TBERegOp("CusMatMulCubeDenseRight") \
 @op_info_register(matmul_cube_dense_right_op_info)
 def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
                             kernel_name="matmulcube"):
+    """CusMatMulCubeDenseRight"""
     shape_a_temp = (128, 63, 16, 16)
     shape_b_temp = (128, 128, 16, 16)
     shape_output = output_y.get("shape")
     matrix_max_shape = (1,)
-    support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape), ]
+    support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape),]
     shape_a_input = input_x1.get("shape")
     shape_b_input = input_x2.get("shape")
     matrix_max_input = input_x3.get("shape")
@@ -62,7 +63,7 @@ def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}
         tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
     input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm)
     input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm)
-    input_x3 = tik_instance.Tensor("float32", [1, ], name="matrix_max", scope=tik.scope_gm)
+    input_x3 = tik_instance.Tensor("float32", [1,], name="matrix_max", scope=tik.scope_gm)
     resMatmul = tik_instance.Tensor("float32", shape_output, name="output", scope=tik.scope_gm)
     with tik_instance.for_range(0, 32, block_num=32) as block_index:
         core_m_idx = block_index // 16
diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py
index ebff84d889..817aeb91d4 100644
--- a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py
@@ -17,9 +17,8 @@ limitations under the License.
 matmul
 """
 from __future__ import absolute_import
-
-import te.platform.cce_params as cce
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+import te.platform.cce_params as cce
 from te import tik
 from topi.cce import util
 
@@ -141,6 +140,7 @@ src_dtype: str
 
 
 def _get_bias(shape_bias):
+    """_get_bias"""
     bias_length = shape_bias[0]
     if bias_length % 16 == 0:
         return shape_bias
@@ -152,6 +152,7 @@ def _get_bias(shape_bias):
 
 
 def _get_input_shape(shape_x):
+    """_get_input_shape"""
     dim_a = shape_x[0]
     dim_b = shape_x[1]
     res = []
@@ -170,6 +171,7 @@ def _get_input_shape(shape_x):
 
 def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
                     kernel_name="matmulcube"):
+    """check_supported"""
     shape_a = input_x1.get("shape")
     shape_b = input_x2.get("shape")
     print("shape_a: ", shape_a)
@@ -180,8 +182,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t
     util.check_shape_rule(shape_b)
     util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
     util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
-    if bias is not None and bool(bias):
-        shape_bias = bias.get("shape")
     try:
         trans_a_f = bool(1 - trans_a)
         if src_dtype == "float32" or src_dtype == "int32":
@@ -265,7 +265,7 @@ def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans
             If True, the input data format of a and b must be fractal format
     shape_bias: list or tuple
             Shape of bias, only support the input data format with ND
-    
+
     Returns
     -------
     None
@@ -381,6 +381,7 @@ def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans
 
 
 def get_cus_tile_info(input_x1, input_x2, diag_size):
+    """get_cus_tile_info"""
     tile_map = {
         ((32, 32, 16, 16), (128, 32, 16, 16)): (8, 8, 16),
         ((8, 8, 16, 16), (72, 8, 16, 16)): (8, 8, 4),
@@ -415,8 +416,9 @@ def get_cus_tile_info(input_x1, input_x2, diag_size):
 
 def cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res,
                          mo_tile, ko_tile, no_tile, diag_opt=False, diag_size=128):
-    ko, mo, mi, ki = input_x1.shape
-    no, ko, ki, ni = input_x2.shape
+    """cus_cube_matmul_cast"""
+    ko, mo, _, _ = input_x1.shape
+    no, ko, ki, _ = input_x2.shape
     c0 = input_x1.shape[-1]
     diag_outer = diag_size // c0
     maxblocknum = 32
diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py
index b5f8ee9d82..e30b19ef6f 100644
--- a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py
@@ -47,6 +47,7 @@ cus_matmul_cube_fracz_right_mul_op_info = TBERegOp("CusMatMulCubeFraczRightMul")
 @op_info_register(cus_matmul_cube_fracz_right_mul_op_info)
 def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
                                kernel_name="matmulcube"):
+    """CusMatMulCubeFraczRightMul"""
     if util.get_product_version() == util.VERSION_MINI:
         tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
     else:
@@ -80,7 +81,7 @@ def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y
                  ((64, 32, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'),
                  ((16, 64, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32')]
     input_shape = (
-        tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype)
+    tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype)
     if input_shape not in Supported:
         raise RuntimeError("input_shape %s is not supported" % str(input_shape))
 
@@ -95,15 +96,17 @@ def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y
 
 
 def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, res):
+    """cus_cube_matmul_right_mul"""
     diag_size = 128
-    ko, mo, mi, ki = input_x1.shape
-    no, ko, ki, ni = input_x2.shape
+    ko, mo, _, _ = input_x1.shape
+    no, ko, ki, _ = input_x2.shape
     c0 = input_x1.shape[-1]
     diag_outer = diag_size // c0
     if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]:
         raise ValueError("shape of input_x1 or input_x2 is not supported!")
 
     def get_cus_tile_info(input_x1, input_x2, input_x3):
+        """get_cus_tile_info"""
         input_shape = (tuple(input_x1.shape), input_x1.dtype, tuple(input_x2.shape), input_x2.dtype,
                        tuple(input_x3.shape), input_x3.dtype)
         tile_map = {
diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py
index dfa83c4fb7..603ed287f6 100644
--- a/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py
@@ -18,11 +18,10 @@ limitations under the License.
 matmul
 """
 from __future__ import absolute_import
-
-import te.lang.cce
-import te.platform.cce_params as cce
 from impl.matmul_vector import matmul_vector_cce
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+import te.lang.cce
+import te.platform.cce_params as cce
 from te import tvm
 from topi import generic
 from topi.cce import util
@@ -146,6 +145,7 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
 
 
 def _get_bias(shape_bias):
+    """_get_bias"""
     bias_length = shape_bias[0]
     if bias_length % 16 == 0:
         return shape_bias
@@ -157,6 +157,7 @@ def _get_bias(shape_bias):
 
 
 def _get_input_shape(shape_x):
+    """_get_input_shape"""
     dim_a = shape_x[0]
     dim_b = shape_x[1]
     res = []
@@ -175,6 +176,7 @@ def _get_input_shape(shape_x):
 
 def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
                     kernel_name="matmulcube"):
+    """check_supported"""
     shape_a = input_x1.get("shape")
     shape_b = input_x2.get("shape")
     print("shape_a: ", shape_a)
@@ -185,8 +187,6 @@ def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, t
     util.check_shape_rule(shape_b)
     util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
     util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
-    if bias is not None and bool(bias):
-        shape_bias = bias.get("shape")
     try:
         trans_a_f = bool(1 - trans_a)
         if src_dtype == "float32" or src_dtype == "int32":
@@ -250,7 +250,7 @@ def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, tra
     """
     calculating matrix multiplication with bias, C = A*B + bias, support input
     data with fractal format.
-    
+
     Parameters:
     shape_a: list or tuple
             Shape of the first tensor a with rank > 1
@@ -269,7 +269,7 @@ def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, tra
             If True, the input data format of a and b must be fractal format
     shape_bias: list or tuple
             Shape of bias, only support the input data format with ND
-    
+
     Returns
     -------
     None
diff --git a/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py b/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py
index f341efe4b7..141e2c1d51 100644
--- a/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py
@@ -32,6 +32,7 @@ cus_transpose02314_op_info = TBERegOp("CusTranspose02314") \
 
 @op_info_register(cus_transpose02314_op_info)
 def CusTranspose02314(input_x, output, kernel_name="transpose021354"):
+    """CusTranspose02314"""
     input_x_shape = input_x.get("shape")
     output_shape = output.get("shape")
     perm = (0, 2, 3, 1, 4)
@@ -263,7 +264,7 @@ def CusTranspose02314(input_x, output, kernel_name="transpose021354"):
 
         with tik_instance.for_range(0, 32, block_num=32) as block_idx:
             with tik_instance.for_range(0, 6, thread_num=2) as cc1:
-                _inner_ + compute(cc1)
+                _inner_compute(cc1)
             _inner_compute(6)
     elif tuple(input_x_shape) == (32, 64, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16":
         def _inner_compute(split_index, block_idx):
diff --git a/mindspore/ops/operations/thor_ops.py b/mindspore/ops/operations/thor_ops.py
index e48180c3f6..5e6ff4b959 100644
--- a/mindspore/ops/operations/thor_ops.py
+++ b/mindspore/ops/operations/thor_ops.py
@@ -91,7 +91,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
     def infer_shape(self, data1_shape):
         ll = []
         if len(data1_shape) == 2:
-            ll = [1, ]
+            ll = [1,]
         else:
             ll = [32, 64]
         return ll
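
Note: the snippet below is an illustration added for readers and is not part of the patch. The thor.py and thor_layer.py hunks above drop the local cus_ops imports and construct the custom THOR operators through mindspore.ops.operations (imported as P), which is where mindspore/ops/operations/thor_ops.py defines the Cus* primitives. A minimal sketch of the resulting call pattern, assuming those primitives are available as shown in the hunks:

    # Illustrative sketch only; mirrors the constructor calls this patch rewrites
    # in example/resnet50_imagenet2012_THOR/model/thor_layer.py.
    from mindspore.ops import operations as P

    cube_matmul = P.CusMatMulCube(transpose_a=True)   # previously CusMatMulCube(...) from the cus_ops package
    matrix_combine = P.CusMatrixCombine()
    cholesky = P.CusCholeskyTrsm()
    vector_matmul = P.CusBatchMatMul()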