Merge pull request !2232 from wandongdong/r0.3 (tags/v0.3.1-alpha)
| @@ -18,6 +18,7 @@ | |||
| from .. import operations as P | |||
| from .grad_base import bprop_getters | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ... import context | |||
| @bprop_getters.register(P.FakeQuantPerLayer) | |||
| @@ -64,12 +65,21 @@ def get_bprop_batchnorm_fold(self): | |||
@bprop_getters.register(P.CorrectionMul)
def get_bprop_correction_mul(self):
    """Generate bprop for CorrectionMul for Ascend and GPU.

    Two backward closures are built and one is returned depending on the
    ``device_target`` read from the global context when this getter runs:

    * ``bprop``: ``CorrectionMulGrad`` yields ``dx`` and ``d_batch_std`` directly.
    * ``bprop_npu`` (Ascend): ``CorrectionMulGrad`` yields ``dx`` and an
      intermediate ``mul_dx``, which ``CorrectionMulGradReduce`` turns into
      ``d_batch_std``.

    Both closures return a zero gradient for ``running_std``.
    """
    grad_dx = P.CorrectionMulGrad(self.channel_axis)
    grad_d_batch_std = P.CorrectionMulGradReduce(self.channel_axis)

    def bprop(x, batch_std, running_std, out, dout):
        dx, d_batch_std = grad_dx(dout, x, batch_std, running_std)
        return dx, d_batch_std, zeros_like(running_std)

    def bprop_npu(x, batch_std, running_std, out, dout):
        # On Ascend the second output is an un-reduced intermediate that
        # still needs the separate reduce op.
        dx, mul_dx = grad_dx(dout, x, batch_std, running_std)
        d_batch_std = grad_d_batch_std(mul_dx)
        return dx, d_batch_std, zeros_like(running_std)

    if context.get_context('device_target') == "Ascend":
        return bprop_npu
    return bprop
| @@ -37,7 +37,7 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \ | |||
| .input(2, "batch_std", None, "required", None) \ | |||
| .input(3, "running_std", None, "required", None) \ | |||
| .output(0, "dx", True, "required", "all") \ | |||
| .output(1, "d_batch_std", True, "required", "all") \ | |||
| .output(1, "mul_dx", True, "required", "all") \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @@ -56,21 +56,14 @@ def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_f | |||
| factor = te.lang.cce.vdiv(batch_std, running_std) | |||
| factor_b = te.lang.cce.broadcast(factor, shape_x) | |||
| dx = te.lang.cce.vmul(dout, factor_b) | |||
| mul_data = te.lang.cce.vmul(dout, x) | |||
| if channel == 0: | |||
| if data_format == "NCHW": | |||
| axis = [1, 2, 3] | |||
| else: | |||
| axis = [1, 2, 3, 4] | |||
| else: | |||
| axis = [2, 3] | |||
| red_data = te.lang.cce.sum(mul_data, axis, keepdims=True) | |||
| d_batch_std = te.lang.cce.vdiv(red_data, running_std) | |||
| return [dx, d_batch_std] | |||
| mul_dx = te.lang.cce.vmul(dout, x) | |||
| running_std_b = te.lang.cce.broadcast(running_std, shape_x) | |||
| mul_dx = te.lang.cce.vdiv(mul_dx, running_std_b) | |||
| return [dx, mul_dx] | |||
| @util.check_input_type(dict, dict, dict, dict, dict, dict, int, str) | |||
| def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"): | |||
| def correction_mul_grad(dout, x, batch_std, running_std, dx, mul_dx, channel, kernel_name="correction_mul_grad"): | |||
| """CorrectionMulGrad op""" | |||
| shape_dout = dout.get("shape") | |||
| shape_x = dout.get("shape") | |||
| @@ -93,7 +86,7 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe | |||
| util.compare_tensor_dict_key(dout, x, "shape") | |||
| util.compare_tensor_dict_key(dx, x, "shape") | |||
| util.compare_tensor_dict_key(batch_std, running_std, "shape") | |||
| util.compare_tensor_dict_key(batch_std, d_batch_std, "shape") | |||
| util.compare_tensor_dict_key(dx, mul_dx, "shape") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape_x) | |||
| @@ -120,7 +113,84 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res_list) | |||
| tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list) | |||
| tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + res_list | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
# TBE registration info for the CorrectionMulGradReduce kernel:
# one required input ("dout"), one required output ("d_batch_std"),
# an optional int attribute "channel_axis", float32 5HD in and out.
correction_mul_grad_reduce_op_info = (
    TBERegOp("CorrectionMulGradReduce")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("correction_mul_grad_reduce.so")
    .compute_cost(10)
    .kernel_name("correction_mul_grad_reduce")
    .partial_flag(True)
    .op_pattern("formatAgnostic")
    .attr("channel_axis", "optional", "int", "all")
    .input(0, "dout", None, "required", None)
    .output(0, "d_batch_std", True, "required", "all")
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD)
    .get_op_info()
)
@op_info_register(correction_mul_grad_reduce_op_info)
def _correction_mul_grad_reduce_tbe():
    """CorrectionMulGradReduce TBE register.

    The body is intentionally empty; the function exists only to carry the
    ``op_info_register`` decorator.
    """
@fusion_manager.register("correction_mul_grad_reduce")
def correction_mul_grad_reduce_compute(mul_dx, channel, data_format, kernel_name="correction_mul"):
    """CorrectionMulGradReduce compute.

    Sums ``mul_dx`` over a set of axes chosen from ``channel`` and
    ``data_format``, keeping the reduced dimensions:

    * ``channel != 0``                         -> axes ``[2, 3]``
    * ``channel == 0`` and format ``"NCHW"``   -> axes ``[1, 2, 3]``
    * ``channel == 0`` and any other format    -> axes ``[1, 2, 3, 4]``

    ``kernel_name`` is accepted but not used by the computation.
    """
    if channel != 0:
        reduce_axes = [2, 3]
    elif data_format == "NCHW":
        reduce_axes = [1, 2, 3]
    else:
        reduce_axes = [1, 2, 3, 4]
    return te.lang.cce.sum(mul_dx, reduce_axes, keepdims=True)
@util.check_input_type(dict, dict, int, str)
def correction_mul_grad_reduce(mul_dx, d_batch_std, channel, kernel_name="correction_mul_grad_reduce"):
    """CorrectionMulGradReduce op.

    Validates the ``mul_dx`` tensor descriptor, builds the reduce computation
    on a placeholder of the same shape/dtype, schedules it and compiles the
    kernel.

    Args:
        mul_dx (dict): Descriptor of the input tensor. The keys ``shape``,
            ``dtype``, ``format`` and ``ori_format`` are read.
        d_batch_std (dict): Descriptor of the output tensor. Present to match
            the registered op's output; it is not read here.
        channel (int): Channel axis selector forwarded to
            ``correction_mul_grad_reduce_compute``.
        kernel_name (str): Name of the generated kernel.

    Raises:
        RuntimeError: If the format is neither ``NC1HWC0`` nor ``NCHW``, or if
            the format is ``NCHW`` while ``ori_format`` is not.
    """
    shape_x = mul_dx.get("shape")
    inp_dtype_dout = mul_dx.get("dtype").lower()

    util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)

    data_format = mul_dx.get("format")
    # Read the *original* format; comparing "format" with itself would make
    # the NCHW consistency check below unable to detect a mismatch.
    ori_format = mul_dx.get("ori_format")
    if data_format.upper() not in ("NC1HWC0", "NCHW"):
        raise RuntimeError("Un supported data format {}".format(data_format))
    if data_format.upper() == "NCHW" and ori_format != "NCHW":
        raise RuntimeError("data_format(NCHW) must same as ori_format")

    dout_t = tvm.placeholder(shape_x, name="dout", dtype=inp_dtype_dout)
    res = correction_mul_grad_reduce_compute(dout_t, channel, data_format, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [dout_t, res]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}

    # Without this call the schedule is built but no kernel is produced.
    te.lang.cce.cce_build_code(sch, config)
| @@ -31,10 +31,12 @@ __all__ = ["FakeQuantPerLayer", | |||
| "BatchNormFoldGrad", | |||
| "CorrectionMul", | |||
| "CorrectionMulGrad", | |||
| "CorrectionMulGradReduce", | |||
| "BatchNormFold2", | |||
| "BatchNormFold2Grad", | |||
| "BatchNormFoldD", | |||
| "BatchNormFoldGradD", | |||
| "BNTrainingReduce", | |||
| "BatchNormFold2_D", | |||
| "BatchNormFold2GradD", | |||
| "BatchNormFold2GradReduce", | |||
| @@ -332,7 +334,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||
| Batch normalization folded. | |||
| Args: | |||
| momentum (float): Momentum value should be [0, 1]. Default: 0.9. | |||
| momentum (float): Momentum value should be [0, 1]. Default: 0.1. | |||
| epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in | |||
| float32 else 1e-3. Default: 1e-5. | |||
| is_training (bool): In training mode set True, else set False. Default: True. | |||
| @@ -364,7 +366,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||
| channel_axis = 1 | |||
| @prim_attr_register | |||
| def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| """init batch norm fold layer""" | |||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | |||
| @@ -499,7 +501,7 @@ class CorrectionMulGrad(PrimitiveWithInfer): | |||
| from mindspore.ops._op_impl._custom_op import correction_mul_grad | |||
| self.channel_axis = channel_axis | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], | |||
| outputs=['dx', 'd_gamma']) | |||
| outputs=['dx', 'mul_dx']) | |||
| def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape): | |||
| validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name) | |||
| @@ -507,12 +509,45 @@ class CorrectionMulGrad(PrimitiveWithInfer): | |||
| Rel.EQ, self.name) | |||
| validator.check("running_std_shape[0]", running_std_shape[0], | |||
| "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name) | |||
| if context.get_context('device_target') == "Ascend": | |||
| return x_shape, x_shape | |||
| return x_shape, gamma_shape | |||
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
    """Check the input dtypes and return the dtypes of the two outputs.

    All four inputs are passed to ``validator.check_tensor_type_same`` with
    float16/float32 as the allowed types. The second output's dtype follows
    ``x`` on Ascend and ``gamma`` on any other device target.
    """
    args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
    validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
    if context.get_context('device_target') == "Ascend":
        return x_type, x_type
    return x_type, gamma_type
class CorrectionMulGradReduce(PrimitiveWithInfer):
    r"""
    Performs grad reduce of CorrectionMul operation.

    Args:
        channel_axis (int): Index of the channel dimension in the input. Default: 0.

    Inputs:
        - **mul_dx** (Tensor) - The intermediate tensor to be reduced.

    Outputs:
        - **d_gamma** (Tensor) - 1-D tensor whose length is the size of the
          `channel_axis` dimension of `mul_dx`, with the same dtype as `mul_dx`.

    Examples:
        >>> correction_mul_grad_rd = P.CorrectionMulGradReduce()
        >>> mul_dx = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
        >>> d_gamma = correction_mul_grad_rd(mul_dx)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """init correction mul reduce layer"""
        if context.get_context('device_target') == "Ascend":
            # Imported only for its side effects; the name itself is unused.
            # NOTE(review): presumably this module registers the TBE kernel - confirm.
            from mindspore.ops._op_impl._custom_op import correction_mul_grad
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['mul_dx'],
                                outputs=['d_gamma'])

    def infer_shape(self, mul_dx_shape):
        """Return a 1-D shape holding the size of the channel dimension."""
        return [mul_dx_shape[self.channel_axis]]

    def infer_dtype(self, mul_dx_type):
        """The output keeps the dtype of the input."""
        return mul_dx_type
| class BatchNormFold2(PrimitiveWithInfer): | |||
| @@ -696,6 +731,32 @@ class BatchNormFoldGradD(PrimitiveWithInfer): | |||
| return x_type | |||
class BNTrainingReduce(PrimitiveWithInfer):
    """
    reduce sum at axis [0, 2, 3].

    Inputs:
        - **x** (Tensor) - Input tensor. NOTE(review): reducing over axes [0, 2, 3]
          implies a 4-D input such as :math:`(N, C, H, W)` - confirm.

    Outputs:
        - **x_sum** (Tensor) - 1-D tensor of shape :math:`(C,)`, where C is the size of
          dimension 1 of x. Has the same dtype as x.
        - **x_square_sum** (Tensor) - 1-D tensor of shape :math:`(C,)`, where C is the
          size of dimension 1 of x. Has the same dtype as x.
    """

    @prim_attr_register
    def __init__(self):
        """init _BNTrainingReduce layer"""
        self.init_prim_io_names(inputs=['x'],
                                outputs=['x_sum', 'x_square_sum'])

    def infer_shape(self, x_shape):
        """Both outputs are 1-D with the length of dimension 1 of x."""
        return [x_shape[1]], [x_shape[1]]

    def infer_dtype(self, x_type):
        """Both outputs keep the dtype of x."""
        return x_type, x_type
| class BatchNormFold2_D(PrimitiveWithInfer): | |||
| """ | |||
| Scale the bias with a correction factor to the long term statistics | |||