| @@ -22,11 +22,15 @@ from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import check_int_positive, check_bool, twice | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.nn.layer.activation import get_activation | |||
| import mindspore.context as context | |||
| __all__ = [ | |||
| 'FakeQuantWithMinMax', | |||
| 'DepthwiseConv2dBatchNormQuant', | |||
| 'Conv2dBatchNormQuant', | |||
| 'Conv2dQuant', | |||
| 'DenseQuant', | |||
| @@ -39,6 +43,169 @@ __all__ = [ | |||
| ] | |||
| class BatchNormFoldCell(Cell): | |||
| """ | |||
| Batch normalization folded. | |||
| Args: | |||
| momentum (float): Momentum value should be [0, 1]. Default: 0.1. | |||
| epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in | |||
| float32 else 1e-3. Default: 1e-5. | |||
| freeze_bn (int): Delay in steps at which computation switches from regular batch | |||
| norm to frozen mean and std. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. | |||
| - **mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **variance** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **global_step** (Tensor) - Tensor to record current global step. | |||
| Outputs: | |||
| Tuple of 4 Tensor, the normalized input and the updated parameters. | |||
| - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| """ | |||
| def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0): | |||
| """init batch norm fold layer""" | |||
| super(BatchNormFoldCell, self).__init__() | |||
| self.epsilon = epsilon | |||
| self.is_gpu = context.get_context('device_target') == "GPU" | |||
| if self.is_gpu: | |||
| self.bn_train = P.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn) | |||
| self.bn_infer = P.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn) | |||
| else: | |||
| self.bn_reduce = P.BNTrainingReduce() | |||
| self.bn_update = P.BatchNormFoldD(momentum, epsilon, is_training=True, freeze_bn=freeze_bn) | |||
| def construct(self, x, mean, variance, global_step): | |||
| if self.is_gpu: | |||
| if self.training: | |||
| batch_mean, batch_std, running_mean, running_std = self.bn_train(x, mean, variance, global_step) | |||
| else: | |||
| batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step) | |||
| else: | |||
| if self.training: | |||
| x_sum, x_square_sum = self.bn_reduce(x) | |||
| _, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \ | |||
| self.bn_update(x, x_sum, x_square_sum, mean, variance) | |||
| P.Assign()(mean, mean_updated) | |||
| P.Assign()(variance, variance_updated) | |||
| else: | |||
| batch_mean = P.ZerosLike()(variance) | |||
| batch_std = P.OnesLike()(variance) | |||
| running_mean = P.TensorAdd()(mean, 0.) | |||
| running_std = P.Sqrt()(P.TensorAdd()(variance, self.epsilon)) | |||
| return batch_mean, batch_std, running_mean, running_std | |||
| class FakeQuantWithMinMaxD(Cell): | |||
| r""" | |||
| Aware Quantization training op of ascend. This OP provide Fake quantization observer | |||
| function on data with min and max. | |||
| Args: | |||
| min_init (int, list): The dimension of channel or 1(layer). Default: -6. | |||
| max_init (int, list): The dimension of channel or 1(layer). Default: 6. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| ema (bool): Exponential Moving Average algorithm update min and max. Default: False. | |||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999. | |||
| per_channel (bool): Quantization by layer or channel. Default: False. | |||
| out_channels (int): declarate the min and max channel size, Default: 1. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of FakeQuantWithMinMax. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `x`. | |||
| Examples: | |||
| >>> fake_quant = nn.FakeQuantWithMinMaxD() | |||
| >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) | |||
| >>> result = fake_quant(input_x) | |||
| """ | |||
| def __init__(self, | |||
| min_init=-6, | |||
| max_init=6, | |||
| num_bits=8, | |||
| ema=False, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| channel_size=1, | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False, | |||
| training=True): | |||
| """init FakeQuantWithMinMax ascend layer""" | |||
| super(FakeQuantWithMinMaxD, self).__init__() | |||
| self.min_init = min_init | |||
| self.num_bits = num_bits | |||
| self.max_init = max_init | |||
| self.ema = ema | |||
| self.ema_decay = ema_decay | |||
| self.per_channel = per_channel | |||
| self.channel_size = channel_size | |||
| self.quant_delay = quant_delay | |||
| self.symmetric = symmetric | |||
| self.narrow_range = narrow_range | |||
| self.training = training | |||
| if not per_channel: | |||
| self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=training) | |||
| self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=training) | |||
| else: | |||
| raise RuntimeError("not support per channel") | |||
| if isinstance(min_init, Parameter): | |||
| self.minq = min_init | |||
| self.maxq = max_init | |||
| else: | |||
| self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)), | |||
| name='quant_min', | |||
| requires_grad=False) | |||
| self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)), | |||
| name='quant_max', | |||
| requires_grad=False) | |||
| self.reduce_min = P.ReduceMin() | |||
| self.reduce_max = P.ReduceMax() | |||
| def extend_repr(self): | |||
| s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format( | |||
| self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size, | |||
| self.quant_delay) | |||
| return s | |||
| def construct(self, x, minq, maxq): | |||
| if self.training: | |||
| min_up, max_up = self.ema_update(x, minq, maxq) | |||
| out = self.fake_quant(x, min_up, max_up) | |||
| P.Assign()(self.minq, min_up) | |||
| P.Assign()(self.maxq, max_up) | |||
| else: | |||
| out = self.fake_quant(x, minq, maxq) | |||
| return out | |||
| class FakeQuantWithMinMax(Cell): | |||
| r""" | |||
| Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max. | |||
| @@ -62,7 +229,7 @@ class FakeQuantWithMinMax(Cell): | |||
| Tensor, with the same type and shape as the `x`. | |||
| Examples: | |||
| >>> fake_quant = nn.FakeQuantWithMinMax() | |||
| >>> fake_quant = FakeQuantWithMinMax() | |||
| >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) | |||
| >>> result = fake_quant(input_x) | |||
| """ | |||
| @@ -77,7 +244,9 @@ class FakeQuantWithMinMax(Cell): | |||
| out_channels=1, | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| training=True): | |||
| """init FakeQuantWithMinMax layer""" | |||
| super(FakeQuantWithMinMax, self).__init__() | |||
| self.min_init = min_init | |||
| @@ -90,12 +259,13 @@ class FakeQuantWithMinMax(Cell): | |||
| self.quant_delay = quant_delay | |||
| self.symmetric = symmetric | |||
| self.narrow_range = narrow_range | |||
| self.training = training | |||
| if per_channel: | |||
| min_array = np.array([self.min_init for i in range( | |||
| 0, self.out_channels)]).astype(np.float32) | |||
| max_array = np.array([self.max_init for i in range( | |||
| 0, self.out_channels)]).astype(np.float32) | |||
| min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) | |||
| max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32) | |||
| self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) | |||
| self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) | |||
| self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| @@ -113,25 +283,44 @@ class FakeQuantWithMinMax(Cell): | |||
| else: | |||
| min_array = np.array([min_init]).reshape(1).astype(np.float32) | |||
| max_array = np.array([max_init]).reshape(1).astype(np.float32) | |||
| self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=True) | |||
| self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=False) | |||
| self.minq = Parameter( | |||
| Tensor(min_array), name='quant_min', requires_grad=False) | |||
| self.maxq = Parameter( | |||
| Tensor(max_array), name='quant_max', requires_grad=False) | |||
| self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) | |||
| self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) | |||
| if context.get_context('device_target') == "Ascend": | |||
| self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=True, | |||
| min_init=self.minq, | |||
| max_init=self.maxq) | |||
| self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=False, | |||
| min_init=self.minq, | |||
| max_init=self.maxq) | |||
| elif context.get_context('device_target') == "GPU": | |||
| self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=True) | |||
| self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=ema_decay, | |||
| quant_delay=quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=False) | |||
| else: | |||
| raise ValueError("Not support platform.") | |||
| def extend_repr(self): | |||
| s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format( | |||
| @@ -146,6 +335,191 @@ class FakeQuantWithMinMax(Cell): | |||
| return out | |||
| class DepthwiseConv2dBatchNormQuant(Cell): | |||
| r""" | |||
| 2D depthwise convolution with BatchNormal op folded layer. | |||
| For a more Detailed overview of Conv2d op. | |||
| Args: | |||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||
| kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. | |||
| stride (int): Specifies stride for all spatial dimensions with the same value. | |||
| pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". | |||
| padding: (int): Implicit paddings on both sides of the input. Default: 0. | |||
| eps (int): Parameters for BatchNormal. Default: 1e-5. | |||
| momentum (int): Parameters for BatchNormal op. Default: 0.9. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| convolution kernel. Default: 'None'. | |||
| beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| beta vector. Default: 'None'. | |||
| gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| gamma vector. Default: 'None'. | |||
| mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| mean vector. Default: 'None'. | |||
| var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| variance vector. Default: 'None'. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. | |||
| fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
| Examples: | |||
| >>> quant = nn.DepthwiseConv2dBatchNormQuant(1, 6, | |||
| kernel_size= (2, 2), | |||
| stride=(1, 1), | |||
| pad_mode="valid", | |||
| >>> dilation=(1, 1)) | |||
| >>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32) | |||
| >>> result = quant(input_x) | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| pad_mode='same', | |||
| padding=0, | |||
| dilation=1, | |||
| group=1, | |||
| eps=1e-5, | |||
| momentum=0.997, | |||
| weight_init=None, | |||
| beta_init=None, | |||
| gamma_init=None, | |||
| mean_init=None, | |||
| var_init=None, | |||
| quant_delay=0, | |||
| freeze_bn=100000, | |||
| fake=True, | |||
| num_bits=8, | |||
| per_channel=False, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| """init DepthwiseConv2dBatchNormQuant layer""" | |||
| super(DepthwiseConv2dBatchNormQuant, self).__init__() | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.pad_mode = pad_mode | |||
| self.padding = padding | |||
| self.dilation = twice(dilation) | |||
| self.stride = twice(stride) | |||
| self.group = group | |||
| self.fake = fake | |||
| self.freeze_bn = freeze_bn | |||
| self.momentum = momentum | |||
| self.quant_delay = quant_delay | |||
| if isinstance(kernel_size, int): | |||
| self.kernel_size = (kernel_size, kernel_size) | |||
| else: | |||
| self.kernel_size = kernel_size | |||
| if group > 1: | |||
| validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant') | |||
| validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant') | |||
| self.is_depthwise = group > 1 | |||
| channel_multiplier = out_channels // in_channels | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| pad_mode=pad_mode, | |||
| pad=padding) | |||
| if weight_init is None: | |||
| weight_init = initializer('normal', [channel_multiplier, in_channels, *kernel_size]) | |||
| self.weight = Parameter(weight_init, name='weight') | |||
| if gamma_init is None: | |||
| gamma_init = initializer('ones', [out_channels]) | |||
| self.gamma = Parameter(gamma_init, name='gamma') | |||
| if beta_init is None: | |||
| beta_init = initializer('zeros', [out_channels]) | |||
| self.beta = Parameter(beta_init, name='beta') | |||
| if mean_init is None: | |||
| mean_init = initializer('zeros', [out_channels]) | |||
| self.moving_mean = Parameter( | |||
| mean_init, name='moving_mean', requires_grad=False) | |||
| if var_init is None: | |||
| var_init = initializer('ones', [out_channels]) | |||
| self.moving_variance = Parameter( | |||
| var_init, name='moving_variance', requires_grad=False) | |||
| self.step = Parameter(initializer( | |||
| 'normal', [1], dtype=mstype.int32), name='step', requires_grad=False) | |||
| self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| per_channel=per_channel, | |||
| out_channels=out_channels, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) | |||
| self.correct_mul = P.CorrectionMul(self.is_depthwise) | |||
| if context.get_context('device_target') == "Ascend": | |||
| self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn) | |||
| self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0) | |||
| elif context.get_context('device_target') == "GPU": | |||
| self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn) | |||
| self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0) | |||
| else: | |||
| raise ValueError("Not support platform.") | |||
| self.one = Tensor(1, mstype.int32) | |||
| self.assignadd = P.AssignAdd() | |||
| self.is_gpu = context.get_context('device_target') == "GPU" | |||
| def extend_repr(self): | |||
| s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ | |||
| 'pad_mode={}, padding={}, dilation={}, group={}, ' \ | |||
| 'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format( | |||
| self.in_channels, self.out_channels, self.kernel_size, self.stride, | |||
| self.pad_mode, self.padding, self.dilation, self.group, | |||
| self.fake, self.freeze_bn, self.momentum, self.quant_delay) | |||
| return s | |||
| def construct(self, x): | |||
| out_conv = self.conv(x, self.weight) | |||
| # BN fold1 | |||
| batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv, | |||
| self.moving_mean, | |||
| self.moving_variance, | |||
| self.step) | |||
| # fake weight | |||
| weight = self.correct_mul(self.weight, self.gamma, running_std) | |||
| if self.fake: | |||
| weight = self.fake_quant_weight(weight) | |||
| out = self.conv(x, weight) | |||
| # BN fold2 | |||
| if self.is_gpu: | |||
| if self.training: | |||
| out = self.batchnorm_fold2_train(out, self.beta, self.gamma, | |||
| batch_std, batch_mean, running_std, running_mean, self.step) | |||
| F.control_depend(out, self.assignadd(self.step, self.one)) | |||
| else: | |||
| out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, | |||
| batch_std, batch_mean, running_std, running_mean, self.step) | |||
| else: | |||
| if self.training: | |||
| out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std) | |||
| F.control_depend(out, self.assignadd(self.step, self.one)) | |||
| else: | |||
| out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std) | |||
| return out | |||
| class Conv2dBatchNormQuant(Cell): | |||
| r""" | |||
| 2D convolution with BatchNormal op folded layer. | |||
| @@ -215,6 +589,7 @@ class Conv2dBatchNormQuant(Cell): | |||
| per_channel=False, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| """init Conv2dBatchNormQuant layer""" | |||
| super(Conv2dBatchNormQuant, self).__init__() | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| @@ -231,7 +606,6 @@ class Conv2dBatchNormQuant(Cell): | |||
| self.kernel_size = (kernel_size, kernel_size) | |||
| else: | |||
| self.kernel_size = kernel_size | |||
| if weight_init is None: | |||
| weight_init = initializer( | |||
| 'normal', [out_channels, in_channels // group, *self.kernel_size]) | |||
| @@ -254,14 +628,6 @@ class Conv2dBatchNormQuant(Cell): | |||
| self.step = Parameter(initializer( | |||
| 'normal', [1], dtype=mstype.int32), name='step', requires_grad=False) | |||
| self.conv = P.Conv2D(out_channel=self.out_channels, | |||
| kernel_size=self.kernel_size, | |||
| mode=1, | |||
| pad_mode=self.pad_mode, | |||
| pad=self.padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=self.group) | |||
| self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| @@ -271,23 +637,29 @@ class Conv2dBatchNormQuant(Cell): | |||
| out_channels=out_channels, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps, | |||
| momentum=momentum, | |||
| is_training=True, | |||
| freeze_bn=freeze_bn) | |||
| self.batchnorm_fold_infer = P.BatchNormFold(epsilon=eps, | |||
| momentum=momentum, | |||
| is_training=False, | |||
| freeze_bn=freeze_bn) | |||
| self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) | |||
| self.conv = P.Conv2D(out_channel=out_channels, | |||
| kernel_size=kernel_size, | |||
| mode=1, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=stride, | |||
| dilation=1, | |||
| group=group) | |||
| self.correct_mul = P.CorrectionMul() | |||
| self.relu = P.ReLU() | |||
| self.batchnorm_fold2 = P.BatchNormFold2(freeze_bn=freeze_bn) | |||
| self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0) | |||
| if context.get_context('device_target') == "Ascend": | |||
| self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn) | |||
| self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0) | |||
| elif context.get_context('device_target') == "GPU": | |||
| self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn) | |||
| self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0) | |||
| else: | |||
| raise ValueError("Not support platform.") | |||
| self.one = Tensor(1, mstype.int32) | |||
| self.assignadd = P.AssignAdd() | |||
| def extend_repr(self): | |||
| s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \ | |||
| s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ | |||
| 'pad_mode={}, padding={}, dilation={}, group={}, ' \ | |||
| 'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format( | |||
| self.in_channels, self.out_channels, self.kernel_size, self.stride, | |||
| @@ -296,34 +668,32 @@ class Conv2dBatchNormQuant(Cell): | |||
| return s | |||
| def construct(self, x): | |||
| if self.training: | |||
| beta = self.beta | |||
| gamma = self.gamma | |||
| gmean = self.moving_mean | |||
| gvar = self.moving_variance | |||
| step = self.step | |||
| out_conv = self.conv(x, self.weight) | |||
| batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train( | |||
| out_conv, gmean, gvar, step) | |||
| # BN fold1 | |||
| weight = self.correct_mul(self.weight, gamma, running_std) | |||
| if self.fake: | |||
| weight = self.fake_quant_weight(weight) | |||
| out = self.conv(x, weight) | |||
| # BN fold2 | |||
| out = self.batchnorm_fold2( | |||
| out, beta, gamma, batch_std, batch_mean, running_std, running_mean, step) | |||
| F.control_depend(out, self.assignadd(self.step, self.one)) | |||
| out_conv = self.conv(x, self.weight) | |||
| # BN fold1 | |||
| batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv, | |||
| self.moving_mean, | |||
| self.moving_variance, | |||
| self.step) | |||
| # fake weight | |||
| weight = self.correct_mul(self.weight, self.gamma, running_std) | |||
| if self.fake: | |||
| weight = self.fake_quant_weight(weight) | |||
| out = self.conv(x, weight) | |||
| # BN fold2 | |||
| if self.is_gpu: | |||
| if self.training: | |||
| out = self.batchnorm_fold2_train(out, self.beta, self.gamma, | |||
| batch_std, batch_mean, running_std, running_mean, self.step) | |||
| F.control_depend(out, self.assignadd(self.step, self.one)) | |||
| else: | |||
| out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, | |||
| batch_std, batch_mean, running_std, running_mean, self.step) | |||
| else: | |||
| step = self.step | |||
| batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer( | |||
| x, self.moving_mean, self.moving_variance, step) | |||
| weight = self.correct_mul(self.weight, self.gamma, running_std) | |||
| if self.fake: | |||
| weight = self.fake_quant_weight(weight) | |||
| out = self.conv(x, weight) | |||
| out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, | |||
| running_std, running_mean, step) | |||
| if self.training: | |||
| out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std) | |||
| F.control_depend(out, self.assignadd(self.step, self.one)) | |||
| else: | |||
| out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std) | |||
| return out | |||
| @@ -434,7 +804,7 @@ class Conv2dQuant(Cell): | |||
| return out | |||
| def extend_repr(self): | |||
| s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \ | |||
| s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ | |||
| 'pad_mode={}, padding={}, dilation={}, group={}, ' \ | |||
| 'has_bias={}, quant_delay={}'.format( | |||
| self.in_channels, self.out_channels, self.kernel_size, self.stride, | |||
| @@ -22,7 +22,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| @bprop_getters.register(P.FakeQuantWithMinMax) | |||
| def get_bprop_fakequant_with_minmax(self): | |||
| """Generate bprop for FakeQuantWithMinMax""" | |||
| """Generate bprop for FakeQuantWithMinMax for GPU and Ascend""" | |||
| op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) | |||
| def bprop(x, x_min, x_max, out, dout): | |||
| @@ -34,7 +34,7 @@ def get_bprop_fakequant_with_minmax(self): | |||
| @bprop_getters.register(P.FakeQuantWithMinMaxPerChannel) | |||
| def get_bprop_fakequant_with_minmax_perchannel(self): | |||
| """Generate bprop for FakeQuantWithMinMaxPerChannel""" | |||
| """Generate bprop for FakeQuantWithMinMaxPerChannel for GPU""" | |||
| op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) | |||
| def bprop(x, x_min, x_max, out, dout): | |||
| @@ -46,7 +46,7 @@ def get_bprop_fakequant_with_minmax_perchannel(self): | |||
| @bprop_getters.register(P.BatchNormFold) | |||
| def get_bprop_batchnorm_fold(self): | |||
| """Generate bprop for BatchNormFold""" | |||
| """Generate bprop for BatchNormFold for GPU""" | |||
| op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn) | |||
| def bprop(x, mean, variance, global_step, out, dout): | |||
| @@ -58,8 +58,8 @@ def get_bprop_batchnorm_fold(self): | |||
| @bprop_getters.register(P.CorrectionMul) | |||
| def get_bprop_correction_mul(self): | |||
| """Generate bprop for CorrectionMul""" | |||
| grad = P.CorrectionMulGrad() | |||
| """Generate bprop for CorrectionMul for Ascend and GPU""" | |||
| grad = P.CorrectionMulGrad(self.channel_axis) | |||
| def bprop(x, batch_std, running_std, out, dout): | |||
| dx, d_batch_std = grad(dout, x, batch_std, running_std) | |||
| @@ -70,7 +70,7 @@ def get_bprop_correction_mul(self): | |||
| @bprop_getters.register(P.BatchNormFold2) | |||
| def get_bprop_batchnorm_fold2(self): | |||
| """Generate bprop for CorrectionAdd""" | |||
| """Generate bprop for BatchNormFold2 for GPU""" | |||
| op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn) | |||
| def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout): | |||
| @@ -80,3 +80,48 @@ def get_bprop_batchnorm_fold2(self): | |||
| zeros_like(global_step) | |||
| return bprop | |||
| @bprop_getters.register(P.BatchNormFoldD) | |||
| def get_bprop_BatchNormFold(self): | |||
| """Generate bprop for BatchNormFold for Ascend""" | |||
| op = P.BatchNormFoldGrad_(self.epsilon, self.is_training, self.freeze_bn) | |||
| def bprop(x, x_sum, x_square_sum, mean, variance, out, dout): | |||
| dx = op(dout[1], dout[2], x, out[1], out[2]) | |||
| return dx, zeros_like(x_sum), zeros_like(x_square_sum), zeros_like(mean), zeros_like(variance) | |||
| return bprop | |||
| @bprop_getters.register(P.BNTrainingReduce) | |||
| def get_bprop_BNTrainingReduce(self): | |||
| def bprop(x, out, dout): | |||
| return (zeros_like(x),) | |||
| return bprop | |||
| @bprop_getters.register(P.BatchNormFold2_D) | |||
| def get_bprop_batchnorm_fold2_(self): | |||
| """Generate bprop for BatchNormFold2 for Ascend""" | |||
| op_reduce = P.BatchNormFold2GradReduce(freeze_bn=self.freeze_bn) | |||
| op_f = P.BatchNormFold2GradD(freeze_bn=self.freeze_bn) | |||
| def bprop(x, beta, gamma, batch_std, batch_mean, running_std, out, dout): | |||
| dout_reduce, dout_x_reduce = op_reduce(dout, x) | |||
| d_batch_std, d_batch_mean, d_gamma, d_x = op_f(dout, dout_reduce, dout_x_reduce, gamma, batch_std, | |||
| batch_mean, running_std) | |||
| return d_x, dout_reduce, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std) | |||
| return bprop | |||
| @bprop_getters.register(P.FakeQuantWithMinMaxUpdate) | |||
| def get_bprop_fakequant_with_minmax_update(self): | |||
| """Generate bprop for FakeQuantWithMinMaxUpdate for Ascend""" | |||
| def bprop(x, x_min, x_max, out, dout): | |||
| return zeros_like(x), zeros_like(x_min), zeros_like(x_max) | |||
| return bprop | |||
| @@ -0,0 +1,149 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """_BatchNormFold op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| from te import tvm | |||
| from topi import generic | |||
| from topi.cce import util | |||
| batch_norm_op_info = TBERegOp("BatchNormFoldD") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batchnorm_fold.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batchnorm_fold") \ | |||
| .partial_flag(True) \ | |||
| .attr("momentum", "optional", "float", "all") \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .attr("is_training", "optional", "bool", "all") \ | |||
| .attr("freeze_bn", "optional", "int", "all") \ | |||
| .attr("data_format", "optional", "str", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "x_sum", False, "required", "all") \ | |||
| .input(2, "x_square_sum", False, "required", "all") \ | |||
| .input(3, "mean", False, "required", "all") \ | |||
| .input(4, "variance", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "batch_mean", False, "required", "all") \ | |||
| .output(2, "batch_std", False, "required", "all") \ | |||
| .output(3, "running_mean", False, "required", "all") \ | |||
| .output(4, "running_std", False, "required", "all") \ | |||
| .output(5, "mean_updated", False, "required", "all") \ | |||
| .output(6, "variance_updated", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(batch_norm_op_info) | |||
| def _batchnorm_fold_tbe(): | |||
| """_BatchNormFold TBE register""" | |||
| return | |||
| @util.check_input_type(dict, dict, dict, dict, dict, | |||
| dict, dict, dict, dict, dict, dict, dict, | |||
| float, float, bool, int, str, str) | |||
| def batchnorm_fold(x, x_sum, x_square_sum, mean, variance, | |||
| y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated, | |||
| momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW", | |||
| kernel_name="batchnorm_fold"): | |||
| """batchnorm_fold TBE op""" | |||
| momentum = 1.0 - momentum | |||
| util.check_kernel_name(kernel_name) | |||
| data_format = data_format.upper() | |||
| if data_format != "NCHW": | |||
| raise RuntimeError("The data_format only support NCHW") | |||
| shape_x = x.get("shape") | |||
| shape_mean = mean.get("shape") | |||
| shape_variance = variance.get("shape") | |||
| dtype_x = x.get("dtype") | |||
| dtype_mean = mean.get("dtype") | |||
| dtype_variance = variance.get("dtype") | |||
| for shape in (shape_x, shape_mean, shape_variance): | |||
| util.check_shape_rule(shape) | |||
| util.check_tensor_shape_size(shape) | |||
| check_tuple = ("float16", "float32") | |||
| for dtype in (dtype_x, dtype_mean, dtype_variance): | |||
| util.check_dtype_rule(dtype.lower(), check_tuple) | |||
| format_data = x.get("format").upper() | |||
| if format_data not in ("NCHW", "NC1HWC0"): | |||
| raise RuntimeError("Format of input only support 4D and 5HD") | |||
| if format_data == "NC1HWC0": | |||
| if len(shape_x) != 5: | |||
| raise RuntimeError("batchnorm_fold only support shape 5D" | |||
| "when input format is NC1HWC0") | |||
| shape_mean = (1, shape_x[1], 1, 1, shape_x[4]) | |||
| elif format_data == "NCHW": | |||
| if len(shape_x) < 2 or len(shape_x) > 4: | |||
| raise RuntimeError("batchnorm_fold only support shape 2D to 4D") | |||
| if shape_x[1] != shape_mean[0]: | |||
| raise RuntimeError("data_format is NCHW, shape_bias must" | |||
| "be equal to the second axis of shape_x") | |||
| shape_mean = (1, shape_x[1],) | |||
| for _ in range(2, len(shape_x)): | |||
| shape_mean = shape_mean + (1,) | |||
| x_input = tvm.placeholder(shape_x, name="x_input", dtype=dtype_x.lower()) | |||
| x_sum = tvm.placeholder(shape_mean, name="x_sum", dtype=dtype_x.lower()) | |||
| x_square_sum = tvm.placeholder(shape_mean, name="x_square_sum", dtype=dtype_x.lower()) | |||
| mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower()) | |||
| variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower()) | |||
| shape_x = te.lang.cce.util.shape_to_list(x_input.shape) | |||
| num = shape_x[0] * shape_x[2] * shape_x[3] | |||
| num_rec = 1.0 / num | |||
| # compute the mean of x | |||
| batch_mean = te.lang.cce.vmuls(x_sum, num_rec) | |||
| # compute the variance of x | |||
| variance_div = te.lang.cce.vmuls(x_square_sum, num_rec) | |||
| mean_square = te.lang.cce.vmul(batch_mean, batch_mean) | |||
| batch_var_biased = te.lang.cce.vsub(variance_div, mean_square) | |||
| if num == 1: | |||
| batch_var_scaler = 0.0 | |||
| else: | |||
| batch_var_scaler = float(num) / (num - 1) | |||
| batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler) | |||
| batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon)) | |||
| factor = 1.0 - momentum | |||
| factor_reverse = momentum | |||
| mean_mul = te.lang.cce.vmuls(batch_mean, factor) | |||
| mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse) | |||
| mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev) | |||
| var_mul = te.lang.cce.vmuls(batch_variance, factor) | |||
| var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse) | |||
| variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev) | |||
| y = te.lang.cce.vadds(x_input, 0.0) | |||
| running_mean = te.lang.cce.vadds(mean, 0.0) | |||
| running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon)) | |||
| res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated] | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res) | |||
| config = {"name": kernel_name, | |||
| "tensor_list": [x_input, x_sum, x_square_sum, mean, variance] + res} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -0,0 +1,110 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """_BatchNormFold2 op""" | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from te.platform.fusion_manager import fusion_manager | |||
| from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batchnorm_fold2.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batchnorm_fold2") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .input(0, "x", None, "required", None) \ | |||
| .input(1, "beta", None, "required", None) \ | |||
| .input(2, "gamma", None, "required", None) \ | |||
| .input(3, "batch_std", None, "required", None) \ | |||
| .input(4, "batch_mean", None, "required", None) \ | |||
| .input(5, "running_std", None, "required", None) \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(batchnorm_fold2_op_info) | |||
| def _batchnorm_fold2_tbe(): | |||
| """_BatchNormFold2 TBE register""" | |||
| return | |||
| @fusion_manager.register("batchnorm_fold2") | |||
| def batchnorm_fold2_compute(x, beta, gamma, batch_std, batch_mean, running_std, kernel_name="batchnorm_fold2"): | |||
| """_BatchNormFold2 compute""" | |||
| shape_x = te.lang.cce.util.shape_to_list(x.shape) | |||
| factor = te.lang.cce.vdiv(running_std, batch_std) | |||
| factor_b = te.lang.cce.broadcast(factor, shape_x) | |||
| res = te.lang.cce.vmul(x, factor_b) | |||
| bias = te.lang.cce.vdiv(batch_mean, batch_std) | |||
| bias = te.lang.cce.vmul(bias, gamma) | |||
| bias = te.lang.cce.vsub(beta, bias) | |||
| bias_b = te.lang.cce.broadcast(bias, shape_x) | |||
| res = te.lang.cce.vadd(res, bias_b) | |||
| return res | |||
| @util.check_input_type(dict, dict, dict, dict, dict, dict, dict, str) | |||
| def batchnorm_fold2(x, beta, gamma, batch_std, batch_mean, running_std, y, kernel_name="batchnorm_fold2"): | |||
| """_BatchNormFold2 op""" | |||
| shape = x.get("shape") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape) | |||
| util.check_shape_size(shape, SHAPE_SIZE_LIMIT) | |||
| check_list = ["float16", "float32"] | |||
| inp_dtype = x.get("dtype").lower() | |||
| if not inp_dtype in check_list: | |||
| raise RuntimeError("Dtype of input only support float16, float32") | |||
| data_format = x.get("format") | |||
| ori_format = x.get("ori_format") | |||
| if data_format.upper() not in ("NC1HWC0", "NCHW"): | |||
| raise RuntimeError("Un supported data format {}".format(data_format)) | |||
| if data_format.upper() == "NCHW" and ori_format != "NCHW": | |||
| raise RuntimeError("data_format(NCHW) must same as ori_format") | |||
| shape_c = gamma.get("shape") | |||
| if gamma.get("format").upper() == "NCHW": | |||
| shape_c = 1, gamma.get("shape")[0], 1, 1 | |||
| x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype) | |||
| beta_t = tvm.placeholder(shape_c, name="beta", dtype=inp_dtype) | |||
| gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype) | |||
| batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype) | |||
| batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype) | |||
| running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype) | |||
| res = batchnorm_fold2_compute(x_t, beta_t, gamma_t, batch_std_t, batch_mean_t, | |||
| running_std_t, kernel_name) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res) | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": [x_t, beta_t, gamma_t, batch_std_t, batch_mean_t, running_std_t, res]} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -0,0 +1,126 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """_BatchNormFold2Grad op""" | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from te.platform.fusion_manager import fusion_manager | |||
| from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batchnorm_fold2_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batchnorm_fold2_grad") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .input(0, "dout", None, "required", None) \ | |||
| .input(1, "dout_reduce", None, "required", None) \ | |||
| .input(2, "dout_x_reduce", None, "required", None) \ | |||
| .input(3, "gamma", None, "required", None) \ | |||
| .input(4, "batch_std", None, "required", None) \ | |||
| .input(5, "batch_mean", None, "required", None) \ | |||
| .input(6, "running_std", None, "required", None) \ | |||
| .output(0, "d_batch_std", True, "required", "all") \ | |||
| .output(1, "d_batch_mean", True, "required", "all") \ | |||
| .output(2, "d_gamma", True, "required", "all") \ | |||
| .output(3, "dx", True, "required", "all") \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(batchnorm_fold2_grad_op_info) | |||
| def _batchnorm_fold2_grad_tbe(): | |||
| """_BatchNormFold2Grad TBE register""" | |||
| return | |||
| @fusion_manager.register("batchnorm_fold2_grad") | |||
| def batchnorm_fold2_grad_compute(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std, | |||
| kernel_name="batchnorm_fold2_grad"): | |||
| """_BatchNormFold2Grad""" | |||
| shape_x = te.lang.cce.util.shape_to_list(dout.shape) | |||
| d_batch_std_1 = te.lang.cce.vmul(dout_reduce, batch_mean) | |||
| d_batch_std_1 = te.lang.cce.vmul(d_batch_std_1, gamma) | |||
| d_batch_std_2 = te.lang.cce.vmul(dout_x_reduce, running_std) | |||
| d_batch_std = te.lang.cce.vsub(d_batch_std_1, d_batch_std_2) | |||
| d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std) | |||
| d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std) | |||
| d_batch_mean = te.lang.cce.vmul(dout_reduce, gamma) | |||
| d_batch_mean = te.lang.cce.vdiv(d_batch_mean, batch_std) | |||
| d_batch_mean = te.lang.cce.vmuls(d_batch_mean, -1.) | |||
| d_gamma = te.lang.cce.vmul(dout_reduce, batch_mean) | |||
| d_gamma = te.lang.cce.vdiv(d_gamma, batch_std) | |||
| d_gamma = te.lang.cce.vmuls(d_gamma, -1.) | |||
| dx = te.lang.cce.vdiv(running_std, batch_std) | |||
| dx = te.lang.cce.broadcast(dx, shape_x) | |||
| dx = te.lang.cce.vmul(dx, dout) | |||
| return [d_batch_std, d_batch_mean, d_gamma, dx] | |||
| @util.check_input_type(dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, str) | |||
| def batchnorm_fold2_grad(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std, d_batch_std, | |||
| d_batch_mean, d_gamma, dx, kernel_name="batchnorm_fold2_grad"): | |||
| """_BatchNormFold2Grad op """ | |||
| shape = dout.get("shape") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape) | |||
| util.check_shape_size(shape, SHAPE_SIZE_LIMIT) | |||
| check_list = ["float16", "float32"] | |||
| inp_dtype = dout.get("dtype").lower() | |||
| if not inp_dtype in check_list: | |||
| raise RuntimeError("Dtype of input only support float16, float32") | |||
| data_format = dout.get("format") | |||
| ori_format = dout.get("ori_format") | |||
| if data_format.upper() not in ("NC1HWC0", "NCHW"): | |||
| raise RuntimeError("Un supported data format {}".format(data_format)) | |||
| if data_format.upper() == "NCHW" and ori_format != "NCHW": | |||
| raise RuntimeError("data_format(NCHW) must same as ori_format") | |||
| shape_c = gamma.get("shape") | |||
| if gamma.get("format").upper() == "NCHW": | |||
| shape_c = 1, gamma.get("shape")[0], 1, 1 | |||
| dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype) | |||
| dout_reduce_t = tvm.placeholder(shape_c, name="dout_reduce", dtype=inp_dtype) | |||
| dout_x_reduce_t = tvm.placeholder(shape_c, name="dout_x_reduce", dtype=inp_dtype) | |||
| gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype) | |||
| batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype) | |||
| batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype) | |||
| running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype) | |||
| res_list = batchnorm_fold2_grad_compute(dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t, | |||
| running_std_t, kernel_name) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res_list) | |||
| tensor_list = [dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t, running_std_t] + list( | |||
| res_list) | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -0,0 +1,107 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """_BatchNormFold2GradReduce op""" | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from te.platform.fusion_manager import fusion_manager | |||
| from te.platform.cce_build import build_config | |||
| from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batchnorm_fold2_grad_reduce.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batchnorm_fold2_grad_reduce") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .input(0, "dout", None, "required", None) \ | |||
| .input(1, "x", None, "required", None) \ | |||
| .output(0, "dout_reduce", True, "required", "all") \ | |||
| .output(1, "dout_x_reduce", True, "required", "all") \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(batchnorm_fold2_grad_reduce_op_info) | |||
| def _batchnorm_fold2_grad_reduce_tbe(): | |||
| """_BatchNormFold2GradReduce TBE register""" | |||
| return | |||
| @fusion_manager.register("batchnorm_fold2_grad_reduce") | |||
| def batchnorm_fold2_grad_reduce_compute(dout, x, dout_args, kernel_name="batchnorm_fold2_grad_reduce"): | |||
| """_BatchNormFold2GradReduce compute""" | |||
| dtype = dout_args.get("dtype") | |||
| dout_format = dout_args.get("format") | |||
| ori_format = dout_args.get("ori_format") | |||
| shape = dout_args.get("shape") | |||
| if dtype == "float16": | |||
| dout = te.lang.cce.cast_to(dout, "float32") | |||
| x = te.lang.cce.cast_to(x, "float32") | |||
| dout_x = te.lang.cce.vmul(dout, x) | |||
| if dout_format == "NC1HWC0": | |||
| axis = [0, 2, 3] | |||
| dout_reduce, dout_x_reduce = te.lang.cce.tuple_sum([dout, dout_x], axis, True) | |||
| else: | |||
| axis = list(range(len(shape))) | |||
| if ori_format == "NCHW": | |||
| axis.pop(1) | |||
| for _, i in enumerate(range(len(shape))): | |||
| if shape[i] == 1 and i in axis: | |||
| axis.remove(i) | |||
| dout_reduce = te.lang.cce.sum(dout, axis, False) | |||
| dout_x_reduce = te.lang.cce.sum(dout_x, axis, False) | |||
| return [dout_reduce, dout_x_reduce] | |||
| @util.check_input_type(dict, dict, dict, dict, str) | |||
| def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name="batchnorm_fold2_grad_reduce"): | |||
| """_BatchNormFold2GradReduce op""" | |||
| shape = x.get("shape") | |||
| x_format = x.get("format") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape) | |||
| util.check_shape_size(shape, SHAPE_SIZE_LIMIT) | |||
| check_list = ["float16", "float32"] | |||
| inp_dtype = x.get("dtype").lower() | |||
| if not inp_dtype in check_list: | |||
| raise RuntimeError("Dtype of input only support float16, float32") | |||
| dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype) | |||
| x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype) | |||
| res_list = batchnorm_fold2_grad_reduce_compute(dout_t, x_t, dout, kernel_name) | |||
| if x_format == "NC1HWC0": | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res_list) | |||
| tensor_list = [dout_t, x_t] + list(res_list) | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| return | |||
| from impl.bn_training_reduce import bn_training_reduce_schedule_nd | |||
| sch, tensor_list = bn_training_reduce_schedule_nd(res_list) | |||
| with build_config: | |||
| tvm.build(sch, tensor_list, "cce", name=kernel_name) | |||
| @@ -0,0 +1,124 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """_BatchNormFoldGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from topi import generic | |||
| from topi.cce import util | |||
| batch_norm_op_info = TBERegOp("BatchNormFoldGradD") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batchnorm_fold_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batchnorm_fold_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .attr("is_training", "optional", "bool", "all") \ | |||
| .attr("freeze_bn", "optional", "int", "all") \ | |||
| .input(0, "d_batch_mean", False, "required", "all") \ | |||
| .input(1, "d_batch_std", False, "required", "all") \ | |||
| .input(2, "x", False, "required", "all") \ | |||
| .input(3, "batch_mean", False, "required", "all") \ | |||
| .input(4, "batch_std", False, "required", "all") \ | |||
| .output(0, "dx", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(batch_norm_op_info) | |||
| def _batchnorm_fold_grad_tbe(): | |||
| """_BatchNormFoldGrad TBE register""" | |||
| return | |||
| def _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std): | |||
| """_batchnorm_fold_grad_compute """ | |||
| shape_x = te.lang.cce.util.shape_to_list(data_x.shape) | |||
| normal_size = shape_x[0] * shape_x[2] * shape_x[3] | |||
| d_batch_mean_broad = te.lang.cce.broadcast(d_batch_mean, shape_x) | |||
| d_batch_std_broad = te.lang.cce.broadcast(d_batch_std, shape_x) | |||
| batch_mean_broad = te.lang.cce.broadcast(batch_mean, shape_x) | |||
| batch_std_broad = te.lang.cce.broadcast(batch_std, shape_x) | |||
| dx = te.lang.cce.vsub(data_x, batch_mean_broad) | |||
| dx = te.lang.cce.vmul(dx, d_batch_std_broad) | |||
| dx = te.lang.cce.vdiv(dx, batch_std_broad) | |||
| dx = te.lang.cce.vadd(dx, d_batch_mean_broad) | |||
| dx = te.lang.cce.vmuls(dx, tvm.const(1. / normal_size, dtype=dx.dtype)) | |||
| return [dx] | |||
| @util.check_input_type(dict, dict, dict, dict, dict, dict, | |||
| float, bool, int, str) | |||
| def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx, | |||
| epsilon=1e-5, is_training=True, freeze_bn=0, kernel_name="batchnorm_fold_grad"): | |||
| """batchnorm_fold_grad op """ | |||
| util.check_kernel_name(kernel_name) | |||
| for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std): | |||
| util.check_shape_rule(iv.get("shape")) | |||
| util.check_tensor_shape_size(iv.get("shape")) | |||
| check_tuple = ("float16", "float32") | |||
| for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std): | |||
| util.check_dtype_rule(iv.get("dtype").lower(), check_tuple) | |||
| shape_x = x.get("shape") | |||
| dtype_x = x.get("dtype") | |||
| format_data = x.get("format").upper() | |||
| if format_data not in ("NCHW", "NC1HWC0"): | |||
| raise RuntimeError("Format of input only support 4D and 5HD") | |||
| shape_mean = d_batch_mean.get("shape") | |||
| dtype_mean = d_batch_mean.get("dtype").lower() | |||
| if format_data == "NC1HWC0": | |||
| if len(shape_x) != 5: | |||
| raise RuntimeError("batchnorm_fold only support shape 5D" | |||
| "when input format is NC1HWC0") | |||
| shape_mean = (1, shape_x[1], 1, 1, shape_x[4]) | |||
| elif format_data == "NCHW": | |||
| if len(shape_x) < 2 or len(shape_x) > 4: | |||
| raise RuntimeError("batchnorm_fold only support shape 2D to 4D") | |||
| if shape_x[1] != shape_mean[0]: | |||
| raise RuntimeError("data_format is NCHW, shape_bias must" | |||
| "be equal to the second axis of shape_x") | |||
| shape_mean = (1, shape_x[1],) | |||
| for _ in range(2, len(shape_x)): | |||
| shape_mean = shape_mean + (1,) | |||
| d_batch_mean = tvm.placeholder(shape_mean, name="d_batch_mean", dtype=dtype_mean) | |||
| d_batch_std = tvm.placeholder(shape_mean, name="d_batch_std", dtype=dtype_mean) | |||
| data_x = tvm.placeholder(shape_x, name="data_x", dtype=dtype_x.lower()) | |||
| batch_mean = tvm.placeholder(shape_mean, name="batch_mean", dtype=dtype_mean) | |||
| batch_std = tvm.placeholder(shape_mean, name="batch_std", dtype=dtype_mean) | |||
| res = _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res) | |||
| tensor_list = [d_batch_mean, d_batch_std, data_x, batch_mean, batch_std] + res | |||
| config = {"name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -0,0 +1,92 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CorrectionMul op""" | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from te.platform.fusion_manager import fusion_manager | |||
| from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| correction_mul_op_info = TBERegOp("CorrectionMul") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("correction_mul.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("correction_mul") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .attr("channel_axis", "optional", "int", "all") \ | |||
| .input(0, "x", None, "required", None) \ | |||
| .input(1, "batch_std", None, "required", None) \ | |||
| .input(2, "running_std", None, "required", None) \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(correction_mul_op_info) | |||
| def _correction_mul_tbe(): | |||
| """CorrectionMul TBE register""" | |||
| return | |||
| @fusion_manager.register("correction_mul") | |||
| def correction_mul_compute(x, batch_std, running_std, kernel_name="correction_mul"): | |||
| """CorrectionMul compute""" | |||
| shape_x = te.lang.cce.util.shape_to_list(x.shape) | |||
| factor = te.lang.cce.vdiv(batch_std, running_std) | |||
| factor_b = te.lang.cce.broadcast(factor, shape_x) | |||
| res = te.lang.cce.vmul(x, factor_b) | |||
| return res | |||
| @util.check_input_type(dict, dict, dict, dict, int, str) | |||
| def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correction_mul"): | |||
| """CorrectionMul op""" | |||
| shape = x.get("shape") | |||
| data_format = x.get("format") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape) | |||
| util.check_shape_size(shape, SHAPE_SIZE_LIMIT) | |||
| check_list = ["float16", "float32"] | |||
| inp_dtype = x.get("dtype").lower() | |||
| if not inp_dtype in check_list: | |||
| raise RuntimeError("Dtype of input only support float16, float32") | |||
| # shape = util.shape_refine(shape) | |||
| x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype) | |||
| shape_c = [1] * len(shape) | |||
| shape_c[channel] = batch_std.get("ori_shape")[0] | |||
| if data_format == "NC1HWC0" and channel == 1: | |||
| shape_c = batch_std.get("shape") | |||
| batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype) | |||
| running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype) | |||
| res = correction_mul_compute(x_t, batch_std_t, running_std_t, kernel_name) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res) | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": [x_t, batch_std_t, running_std_t, res]} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -0,0 +1,134 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CorrectionMul op""" | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from te.platform.fusion_manager import fusion_manager | |||
| from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("correction_mul_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("correction_mul_grad") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .attr("channel_axis", "optional", "int", "all") \ | |||
| .input(0, "dout", None, "required", None) \ | |||
| .input(1, "x", None, "required", None) \ | |||
| .input(2, "batch_std", None, "required", None) \ | |||
| .input(3, "running_std", None, "required", None) \ | |||
| .output(0, "dx", True, "required", "all") \ | |||
| .output(1, "d_batch_std", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(correction_mul_grad_op_info) | |||
| def _correction_mul_grad_tbe(): | |||
| """CorrectionMulGrad TBE register""" | |||
| return | |||
| @fusion_manager.register("correction_mul_grad") | |||
| def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_format, kernel_name="correction_mul"): | |||
| """CorrectionMulGrad compute""" | |||
| shape_x = te.lang.cce.util.shape_to_list(x.shape) | |||
| factor = te.lang.cce.vdiv(batch_std, running_std) | |||
| factor_b = te.lang.cce.broadcast(factor, shape_x) | |||
| dx = te.lang.cce.vmul(dout, factor_b) | |||
| mul_data = te.lang.cce.vmul(dout, x) | |||
| if channel == 0: | |||
| if data_format == "NCHW": | |||
| axis = [1, 2, 3] | |||
| else: | |||
| axis = [1, 2, 3, 4] | |||
| else: | |||
| axis = [2, 3] | |||
| red_data = te.lang.cce.sum(mul_data, axis, keepdims=True) | |||
| d_batch_std = te.lang.cce.vdiv(red_data, running_std) | |||
| return [dx, d_batch_std] | |||
| @util.check_input_type(dict, dict, dict, dict, dict, dict, int, str) | |||
| def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"): | |||
| """CorrectionMulGrad op""" | |||
| shape_dout = dout.get("shape") | |||
| shape_x = dout.get("shape") | |||
| dtype_dout = dout.get("dtype") | |||
| dtype_x = x.get("dtype") | |||
| dtype_batch_std = batch_std.get("dtype") | |||
| dtype_running_std = running_std.get("dtype") | |||
| inp_dtype_dout = dtype_dout.lower() | |||
| inp_dtype_x = dtype_x.lower() | |||
| inp_dtype_batch_std = dtype_batch_std.lower() | |||
| inp_dtype_running_std = dtype_running_std.lower() | |||
| util.check_dtype_rule(inp_dtype_dout, ("float16", "float32")) | |||
| util.check_dtype_rule(inp_dtype_x, ("float16", "float32")) | |||
| util.check_dtype_rule(inp_dtype_batch_std, ("float32",)) | |||
| util.check_dtype_rule(inp_dtype_running_std, ("float32",)) | |||
| util.compare_tensor_dict_key(dout, x, "dtype") | |||
| util.compare_tensor_dict_key(dout, x, "shape") | |||
| util.compare_tensor_dict_key(dx, x, "shape") | |||
| util.compare_tensor_dict_key(batch_std, running_std, "shape") | |||
| util.compare_tensor_dict_key(batch_std, d_batch_std, "shape") | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(shape_x) | |||
| util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT) | |||
| data_format = dout.get("format") | |||
| ori_format = dout.get("format") | |||
| if data_format.upper() not in ("NC1HWC0", "NCHW"): | |||
| raise RuntimeError("Un supported data format {}".format(data_format)) | |||
| if data_format.upper() == "NCHW" and ori_format != "NCHW": | |||
| raise RuntimeError("data_format(NCHW) must same as ori_format") | |||
| shape_c = [1] * len(shape_x) | |||
| shape_c[channel] = batch_std.get("ori_shape")[0] | |||
| if data_format == "NC1HWC0" and channel == 1: | |||
| shape_c = batch_std.get("shape") | |||
| dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout) | |||
| x_t = tvm.placeholder(shape_x, name="x", dtype=inp_dtype_x) | |||
| batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype_batch_std) | |||
| running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype_running_std) | |||
| res_list = correction_mul_grad_compute(dout_t, x_t, batch_std_t, running_std_t, channel, data_format, kernel_name) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res_list) | |||
| tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list) | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -0,0 +1,146 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """FakeQuantWithMinMax op""" | |||
| from functools import reduce as functools_reduce | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from te.platform.fusion_manager import fusion_manager | |||
| from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fake_quant_with_min_max_vars_ema.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fake_quant_with_min_max_vars_ema") \ | |||
| .partial_flag(True) \ | |||
| .attr("ema", "optional", "bool", "all") \ | |||
| .attr("ema_decay", "optional", "float", "all") \ | |||
| .attr("symmetric", "optional", "bool", "all") \ | |||
| .attr("narrow_range", "optional", "bool", "all") \ | |||
| .attr("training", "optional", "bool", "all") \ | |||
| .attr("num_bits", "optional", "int", "all") \ | |||
| .attr("quant_delay", "optional", "int", "all") \ | |||
| .input(0, "x", None, "required", None) \ | |||
| .input(1, "min", None, "required", None) \ | |||
| .input(2, "max", None, "required", None) \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(fake_quant_op_info) | |||
| def _fake_quant_tbe(): | |||
| """FakeQuantWithMinMax TBE register""" | |||
| return | |||
| @fusion_manager.register("fake_quant_with_min_max_vars_ema") | |||
| def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max, | |||
| kernel_name="correction_mul"): | |||
| """FakeQuantWithMinMax""" | |||
| shape = te.lang.cce.util.shape_to_list(x.shape) | |||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | |||
| quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype) | |||
| quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype) | |||
| min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) | |||
| max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) | |||
| # CalNudge(NudgeMinMax) | |||
| scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) | |||
| zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) | |||
| # Nudge zero point | |||
| nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min))) | |||
| nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) | |||
| nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) | |||
| # boradcast to shape | |||
| nudge_min = te.lang.cce.broadcast(nudge_min, shape, x.dtype) | |||
| nudge_max = te.lang.cce.broadcast(nudge_max, shape, x.dtype) | |||
| scale = te.lang.cce.broadcast(scale, shape, x.dtype) | |||
| # FakeQuant | |||
| input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x)) | |||
| nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale), | |||
| 0.5)) | |||
| res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min) | |||
| return res | |||
| @util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) | |||
| def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y, | |||
| ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay, | |||
| kernel_name="fake_quant"): | |||
| """FakeQuantWithMinMax""" | |||
| input_shape = x.get("shape") | |||
| input_dtype = x.get("dtype") | |||
| min_shape = min_val.get("ori_shape") | |||
| min_dtype = min_val.get("dtype") | |||
| max_shape = max_val.get("ori_shape") | |||
| max_dtype = max_val.get("dtype") | |||
| min_shape = util.scalar2tensor_one(min_shape) | |||
| max_shape = util.scalar2tensor_one(max_shape) | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(input_shape) | |||
| util.check_shape_rule(min_shape, 1, 1, 1) | |||
| util.check_shape_rule(max_shape, 1, 1, 1) | |||
| util.check_tensor_shape_size(input_shape) | |||
| util.check_tensor_shape_size(min_shape) | |||
| util.check_tensor_shape_size(max_shape) | |||
| check_list = ["float32", "float16"] | |||
| x_dtype = input_dtype.lower() | |||
| min_dtype = min_dtype.lower() | |||
| max_dtype = max_dtype.lower() | |||
| util.check_dtype_rule(x_dtype, check_list) | |||
| util.check_dtype_rule(min_dtype, check_list) | |||
| util.check_dtype_rule(max_dtype, check_list) | |||
| input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) | |||
| shape_min, _, _ = util.produce_shapes(min_shape, input_shape) | |||
| if symmetric: | |||
| quant_min = 0 - 2 ** (num_bits - 1) | |||
| quant_max = 2 ** (num_bits - 1) - 1 | |||
| else: | |||
| quant_min = 0 | |||
| quant_max = 2 ** num_bits - 1 | |||
| if narrow_range: | |||
| quant_min = quant_min + 1 | |||
| input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | |||
| min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | |||
| max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | |||
| res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y, | |||
| quant_min, quant_max, kernel_name) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res) | |||
| tensor_list = [input_data, min_data, max_data, res] | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -0,0 +1,156 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """FakeQuantWithMinMaxGrad op""" | |||
| from functools import reduce as functools_reduce | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from te.platform.fusion_manager import fusion_manager | |||
| from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| SHAPE_SIZE_LIMIT = 2147483648 | |||
| D_TYPE = 'float32' | |||
| fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fake_quant_with_min_max_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fake_quant_with_min_max_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("num_bits", "optional", "int", "all") \ | |||
| .attr("quant_delay", "optional", "int", "all") \ | |||
| .input(0, "dout", None, "required", None) \ | |||
| .input(1, "x", None, "required", None) \ | |||
| .input(2, "min", None, "required", None) \ | |||
| .input(3, "max", None, "required", None) \ | |||
| .output(0, "dx", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| def _less_compare_float32(data_x, data_y): | |||
| """_less_compare_float32 compute""" | |||
| shape_inputs = te.lang.cce.util.shape_to_list(data_x.shape) | |||
| min_value = tvm.const(2 ** (-126), dtype=D_TYPE) | |||
| max_value = tvm.const(2 ** 62, dtype=D_TYPE) | |||
| factor_value = tvm.const(2 ** 2, dtype=D_TYPE) | |||
| data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE) | |||
| min_value_tensor = te.lang.cce.vadds(data_zero, min_value) | |||
| res_sub = te.lang.cce.vsub(data_y, data_x) | |||
| res_min = te.lang.cce.vmin(res_sub, min_value_tensor) | |||
| res_max = te.lang.cce.vmax(res_min, data_zero) | |||
| res_max_mul = te.lang.cce.vmuls(res_max, max_value) | |||
| res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value) | |||
| res = te.lang.cce.vmuls(res_max_mul_max, factor_value) | |||
| return res | |||
| @op_info_register(fake_quant_grad_op_info) | |||
| def _fake_quant_grad_tbe(): | |||
| """FakeQuantWithMinMaxGrad TBE register""" | |||
| return | |||
| @fusion_manager.register("fake_quant_with_min_max_grad") | |||
| def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, | |||
| kernel_name="fake_quant_with_min_max_grad"): | |||
| """FakeQuantWithMinMaxGrad""" | |||
| shape = te.lang.cce.util.shape_to_list(x.shape) | |||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | |||
| quant_min = tvm.const(quant_min, x.dtype) | |||
| quant_max = tvm.const(quant_max, x.dtype) | |||
| quant_min = te.lang.cce.broadcast(quant_min, shape_min) | |||
| quant_max = te.lang.cce.broadcast(quant_max, shape_min) | |||
| # CalNudge(NudgeMinMax) | |||
| scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) | |||
| zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) | |||
| # Nudge zero point | |||
| nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min))) | |||
| nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) | |||
| nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) | |||
| nudge_min = te.lang.cce.broadcast(nudge_min, shape) | |||
| nudge_max = te.lang.cce.broadcast(nudge_max, shape) | |||
| bool_over_min = _less_compare_float32(nudge_min, x) | |||
| bool_less_max = _less_compare_float32(x, nudge_max) | |||
| bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max) | |||
| res = te.lang.cce.vmul(dout, bool_between) | |||
| return res | |||
| @util.check_input_type(dict, dict, dict, dict, dict, int, int, str) | |||
| def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay, | |||
| kernel_name="fake_quant_with_min_max_grad"): | |||
| """FakeQuantWithMinMaxGrad""" | |||
| input_shape = x.get("shape") | |||
| input_dtype = x.get("dtype") | |||
| min_shape = min_val.get("ori_shape") | |||
| min_dtype = min_val.get("dtype") | |||
| max_shape = max_val.get("ori_shape") | |||
| max_dtype = max_val.get("dtype") | |||
| min_shape = util.scalar2tensor_one(min_shape) | |||
| max_shape = util.scalar2tensor_one(max_shape) | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(input_shape) | |||
| util.check_shape_rule(min_shape, 1, 1, 1) | |||
| util.check_shape_rule(max_shape, 1, 1, 1) | |||
| util.check_tensor_shape_size(input_shape) | |||
| util.check_tensor_shape_size(min_shape) | |||
| util.check_tensor_shape_size(max_shape) | |||
| check_list = ["float32", 'float16'] | |||
| x_dtype = input_dtype.lower() | |||
| min_dtype = min_dtype.lower() | |||
| max_dtype = max_dtype.lower() | |||
| util.check_dtype_rule(x_dtype, check_list) | |||
| util.check_dtype_rule(min_dtype, check_list) | |||
| util.check_dtype_rule(max_dtype, check_list) | |||
| input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) | |||
| shape_min, _, _ = util.produce_shapes(min_shape, input_shape) | |||
| quant_min = 0 | |||
| quant_max = 2 ** num_bits - 1 | |||
| dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype) | |||
| input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | |||
| min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | |||
| max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | |||
| res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min, | |||
| quant_max, kernel_name) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res) | |||
| tensor_list = [dout_data, input_data, min_data, max_data, res] | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -0,0 +1,137 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """FakeQuantWithMinMaxUpdate op""" | |||
| from functools import reduce as functools_reduce | |||
| import te.lang.cce | |||
| from te import tvm | |||
| from te.platform.fusion_manager import fusion_manager | |||
| from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fake_quant_with_min_max_update5d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fake_quant_with_min_max_update") \ | |||
| .partial_flag(True) \ | |||
| .attr("ema", "optional", "bool", "all") \ | |||
| .attr("ema_decay", "optional", "float", "all") \ | |||
| .attr("symmetric", "optional", "bool", "all") \ | |||
| .attr("narrow_range", "optional", "bool", "all") \ | |||
| .attr("training", "optional", "bool", "all") \ | |||
| .attr("num_bits", "optional", "int", "all") \ | |||
| .attr("quant_delay", "optional", "int", "all") \ | |||
| .input(0, "x", None, "required", None) \ | |||
| .input(1, "min", None, "required", None) \ | |||
| .input(2, "max", None, "required", None) \ | |||
| .output(0, "min_up", True, "required", "all") \ | |||
| .output(1, "max_up", True, "required", "all") \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register(fake_quant_update5d_op_info) | |||
| def _fake_quant_update5d_tbe(): | |||
| """_FakeQuantWithMinMaxUpdate5D TBE register""" | |||
| return | |||
| @fusion_manager.register("fake_quant_with_min_max_update") | |||
| def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, | |||
| kernel_name="fake_quant_update"): | |||
| """FakeQuantWithMinMaxUpdate compute""" | |||
| shape = te.lang.cce.util.shape_to_list(x.shape) | |||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | |||
| min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) | |||
| max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) | |||
| if not ema: | |||
| ema_decay = 0.0 | |||
| if training: | |||
| # CalMinMax | |||
| axis = tuple(range(len(shape))) | |||
| x_min = te.lang.cce.reduce_min(x, axis=axis) | |||
| x_max = te.lang.cce.reduce_max(x, axis=axis) | |||
| x_min = te.lang.cce.broadcast(x_min, shape_min) | |||
| x_max = te.lang.cce.broadcast(x_max, shape_min) | |||
| min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) | |||
| max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) | |||
| min_val = te.lang.cce.vmins(min_val, 0) | |||
| max_val = te.lang.cce.vmaxs(max_val, 0) | |||
| return [min_val, max_val] | |||
| @util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) | |||
| def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up, | |||
| ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay, | |||
| kernel_name="fake_quant_update"): | |||
| """FakeQuantWithMinMax op""" | |||
| input_shape = x.get("shape") | |||
| input_dtype = x.get("dtype") | |||
| min_shape = min_val.get("ori_shape") | |||
| min_dtype = min_val.get("dtype") | |||
| max_shape = max_val.get("ori_shape") | |||
| max_dtype = max_val.get("dtype") | |||
| min_shape = util.scalar2tensor_one(min_shape) | |||
| max_shape = util.scalar2tensor_one(max_shape) | |||
| util.check_kernel_name(kernel_name) | |||
| util.check_shape_rule(input_shape) | |||
| util.check_shape_rule(min_shape, 1, 1, 1) | |||
| util.check_shape_rule(max_shape, 1, 1, 1) | |||
| util.check_tensor_shape_size(input_shape) | |||
| util.check_tensor_shape_size(min_shape) | |||
| util.check_tensor_shape_size(max_shape) | |||
| check_list = ["float32", "float16"] | |||
| x_dtype = input_dtype.lower() | |||
| min_dtype = min_dtype.lower() | |||
| max_dtype = max_dtype.lower() | |||
| util.check_dtype_rule(x_dtype, check_list) | |||
| util.check_dtype_rule(min_dtype, check_list) | |||
| util.check_dtype_rule(max_dtype, check_list) | |||
| input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) | |||
| shape_min, _, _ = util.produce_shapes(min_shape, input_shape) | |||
| if symmetric: | |||
| quant_min = 0 - 2 ** (num_bits - 1) | |||
| quant_max = 2 ** (num_bits - 1) - 1 | |||
| else: | |||
| quant_min = 0 | |||
| quant_max = 2 ** num_bits - 1 | |||
| if narrow_range: | |||
| quant_min = quant_min + 1 | |||
| input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | |||
| min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | |||
| max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | |||
| res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data, | |||
| ema, ema_decay, quant_min, quant_max, training, kernel_name) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res_list) | |||
| tensor_list = [input_data, min_data, max_data] + list(res_list) | |||
| config = {"print_ir": False, | |||
| "name": kernel_name, | |||
| "tensor_list": tensor_list} | |||
| te.lang.cce.cce_build_code(sch, config) | |||
| @@ -30,6 +30,10 @@ __all__ = ["FakeQuantWithMinMax", | |||
| "CorrectionMulGrad", | |||
| "BatchNormFold2", | |||
| "BatchNormFold2Grad", | |||
| "BatchNormFoldD", | |||
| "BNTrainingReduce", | |||
| "BatchNormFold2_D", | |||
| "FakeQuantWithMinMaxUpdate", | |||
| ] | |||
| @@ -166,7 +170,7 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): | |||
| >>> result = fake_quant(input_x, _min, _max) | |||
| """ | |||
| support_quant_bit = [4, 8] | |||
| channel_idx = 0 | |||
| channel_axis = 0 | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, | |||
| @@ -188,8 +192,8 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) | |||
| validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name) | |||
| validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name) | |||
| validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||
| validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| @@ -272,7 +276,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||
| >>> global_step = Tensor(np.arange(6), mindspore.int32) | |||
| >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step) | |||
| """ | |||
| channel = 1 | |||
| channel_axis = 1 | |||
| @prim_attr_register | |||
| def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| @@ -287,7 +291,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): | |||
| validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) | |||
| validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, self.name) | |||
| validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) | |||
| validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) | |||
| return mean_shape, mean_shape, mean_shape, mean_shape | |||
| @@ -314,7 +318,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer): | |||
| >>> global_step = Tensor([2], mindspore.int32) | |||
| >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step) | |||
| """ | |||
| channel = 1 | |||
| channel_axis = 1 | |||
| @prim_attr_register | |||
| def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| @@ -333,8 +337,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer): | |||
| "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) | |||
| validator.check("d_batch_mean shape", d_batch_mean_shape, | |||
| "batch_std shape", batch_std_shape, Rel.EQ, self.name) | |||
| validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, | |||
| self.name) | |||
| validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], | |||
| "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) | |||
| validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) | |||
| return x_shape | |||
| @@ -368,17 +372,17 @@ class CorrectionMul(PrimitiveWithInfer): | |||
| >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32) | |||
| >>> out = correction_mul(input_x, batch_std, running_std) | |||
| """ | |||
| channel = 0 | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| def __init__(self, channel_axis=0): | |||
| """init correction mul layer""" | |||
| self.channel_axis = channel_axis | |||
| self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'], | |||
| outputs=['out']) | |||
| def infer_shape(self, x_shape, batch_std_shape, running_std_shape): | |||
| validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel], | |||
| validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis], | |||
| Rel.EQ, self.name) | |||
| return x_shape | |||
| @@ -400,20 +404,20 @@ class CorrectionMulGrad(PrimitiveWithInfer): | |||
| >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32) | |||
| >>> result = correction_mul_grad(dout, input_x, gamma, running_std) | |||
| """ | |||
| channel = 0 | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| def __init__(self, channel_axis=0): | |||
| """init correction mul layer""" | |||
| self.channel_axis = channel_axis | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], | |||
| outputs=['dx', 'd_gamma']) | |||
| def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape): | |||
| validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name) | |||
| validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel], | |||
| Rel.EQ, self.name) | |||
| validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel], | |||
| validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis], | |||
| Rel.EQ, self.name) | |||
| validator.check("running_std_shape[0]", running_std_shape[0], | |||
| "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name) | |||
| return x_shape, gamma_shape | |||
| def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): | |||
| @@ -454,7 +458,7 @@ class BatchNormFold2(PrimitiveWithInfer): | |||
| >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean, | |||
| >>> running_std, running_mean, global_step) | |||
| """ | |||
| channel = 1 | |||
| channel_axis = 1 | |||
| @prim_attr_register | |||
| def __init__(self, freeze_bn=0): | |||
| @@ -471,7 +475,7 @@ class BatchNormFold2(PrimitiveWithInfer): | |||
| validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel], | |||
| validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis], | |||
| Rel.EQ, self.name) | |||
| validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) | |||
| return x_shape | |||
| @@ -501,7 +505,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer): | |||
| >>> global_step = Tensor(np.array([-2]), mindspore.int32) | |||
| >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step) | |||
| """ | |||
| channel = 1 | |||
| channel_axis = 1 | |||
| @prim_attr_register | |||
| def __init__(self, freeze_bn=0): | |||
| @@ -519,7 +523,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer): | |||
| validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel], | |||
| validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis], | |||
| Rel.EQ, self.name) | |||
| validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) | |||
| return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape | |||
| @@ -542,3 +546,259 @@ class BatchNormFold2Grad(PrimitiveWithInfer): | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||
| return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type | |||
| class BatchNormFoldD(PrimitiveWithInfer): | |||
| """Performs grad of _BatchNormFold operation.""" | |||
| @prim_attr_register | |||
| def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| """init _BatchNormFold layer""" | |||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold | |||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | |||
| self.data_format = "NCHW" | |||
| self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'], | |||
| outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std', | |||
| 'mean_updated', 'variance_updated']) | |||
| def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape): | |||
| validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) | |||
| validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name) | |||
| return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape | |||
| def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type): | |||
| validator.check("input type", x_type, "mean type", mean_type) | |||
| validator.check("input type", x_type, "variance type", variance_type) | |||
| args = {"x": x_type, "mean": mean_type, "variance": variance_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| return x_type, x_type, x_type, x_type, x_type, x_type, x_type | |||
| class BatchNormFoldGradD(PrimitiveWithInfer): | |||
| """Performs grad of _BatchNormFoldGrad operation.""" | |||
| @prim_attr_register | |||
| def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| """init _BatchNormFoldGrad layer""" | |||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | |||
| self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'], | |||
| outputs=['dx']) | |||
| def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape): | |||
| validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape) | |||
| validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape) | |||
| validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape) | |||
| validator.check("x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[1]) | |||
| return x_shape | |||
| def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type): | |||
| validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type) | |||
| validator.check("input type", x_type, "d_batch_std type", d_batch_std_type) | |||
| validator.check("input type", x_type, "batch_mean type", batch_mean_type) | |||
| validator.check("input type", x_type, "batch_std type", batch_std_type) | |||
| args = {"input type": x_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| return x_type | |||
| class BNTrainingReduce(PrimitiveWithInfer): | |||
| """ | |||
| reduce sum at axis [0, 2, 3]. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C)`. | |||
| Outputs: | |||
| - **x_sum** (Tensor) - Tensor has the same shape as x. | |||
| - **x_square_sum** (Tensor) - Tensor has the same shape as x. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init _BNTrainingReduce layer""" | |||
| self.init_prim_io_names(inputs=['x'], | |||
| outputs=['x_sum', 'x_square_sum']) | |||
| def infer_shape(self, x_shape): | |||
| return [x_shape[1]], [x_shape[1]] | |||
| def infer_dtype(self, x_type): | |||
| return x_type, x_type | |||
| class BatchNormFold2_D(PrimitiveWithInfer): | |||
| """ | |||
| Scale the bias with a correction factor to the long term statistics | |||
| prior to quantization. This ensures that there is no jitter in the quantized bias | |||
| due to batch to batch variation. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C)`. | |||
| - **beta** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **gamma** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **global_step** (Tensor) - Tensor to record current global step. | |||
| Outputs: | |||
| - **y** (Tensor) - Tensor has the same shape as x. | |||
| """ | |||
| channel_axis = 1 | |||
| @prim_attr_register | |||
| def __init__(self, freeze_bn=0): | |||
| """init conv2d fold layer""" | |||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold2 | |||
| self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'], | |||
| outputs=['y']) | |||
| def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape): | |||
| validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis], | |||
| Rel.EQ, self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type): | |||
| args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, | |||
| "beta": beta_type, "gamma": gamma_type, "x": x_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| return x_type | |||
| class BatchNormFold2GradD(PrimitiveWithInfer): | |||
| """Performs grad of CorrectionAddGrad operation.""" | |||
| channel_axis = 1 | |||
| @prim_attr_register | |||
| def __init__(self, freeze_bn=False): | |||
| """init MulFold layer""" | |||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad | |||
| self.freeze_bn = freeze_bn | |||
| self.init_prim_io_names( | |||
| inputs=['dout', 'dout_reduce', 'dout_x_reduce', 'gamma', 'batch_std', 'batch_mean', 'running_std'], | |||
| outputs=['d_batch_std', 'd_batch_mean', 'd_gamma', 'dx']) | |||
| def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape, | |||
| batch_mean_shape, running_std_shape): | |||
| validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis], | |||
| Rel.EQ, self.name) | |||
| return gamma_shape, gamma_shape, gamma_shape, dout_shape | |||
| def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type, | |||
| batch_mean_type, running_std_type): | |||
| validator.check("batch_std type", batch_std_type, | |||
| "batch_mean type", batch_mean_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "gamma type", gamma_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "running_std type", running_std_type) | |||
| validator.check("batch_std_type", batch_std_type, | |||
| "dout type", dout_type) | |||
| args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, | |||
| "running_std": running_std_type, "dout": dout_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| return gamma_type, gamma_type, gamma_type, gamma_type | |||
| class BatchNormFold2GradReduce(PrimitiveWithInfer): | |||
| """Performs grad of CorrectionAddGrad operation.""" | |||
| channel_axis = 1 | |||
| @prim_attr_register | |||
| def __init__(self, freeze_bn=False): | |||
| """init MulFold layer""" | |||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce | |||
| self.freeze_bn = freeze_bn | |||
| self.init_prim_io_names(inputs=['dout', 'x'], | |||
| outputs=['dout_reduce', 'dout_x_reduce']) | |||
| def infer_shape(self, dout_shape, x_shape): | |||
| validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) | |||
| return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],) | |||
| def infer_dtype(self, dout_type, x_type): | |||
| validator.check("dout type", dout_type, "x type", x_type) | |||
| return dout_type, dout_type | |||
| class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer): | |||
| r""" | |||
| Simulate the quantize and dequantize operations in training time. | |||
| Args: | |||
| num_bits (int) : Number bits for aware quantilization. Default: 8. | |||
| ema (bool): Use EMA algorithm update value min and max. Default: False. | |||
| ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | |||
| quant_delay (int): Quantilization delay parameter. Before delay step in training time not update | |||
| simulate aware quantize funcion. After delay step in training time begin simulate the aware | |||
| quantize funcion. Default: 0. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| training (bool): Training the network or not. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. | |||
| - **min** (Tensor) : Value of the min range of the input data x. | |||
| - **max** (Tensor) : Value of the max range of the input data x. | |||
| Outputs: | |||
| - Tensor: Simulate quantize tensor of x. | |||
| Examples: | |||
| >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | |||
| >>> min_tensor = Tensor(np.array([-6]), mstype.float32) | |||
| >>> max_tensor = Tensor(np.array([6]), mstype.float32) | |||
| >>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) | |||
| """ | |||
| support_quant_bit = [4, 7, 8] | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, | |||
| training=True): | |||
| """init FakeQuantWithMinMax OP""" | |||
| from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| if ema and not ema_decay: | |||
| raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | |||
| self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) | |||
| self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) | |||
| self.training = validator.check_value_type('training', training, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) | |||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | |||
| outputs=['min_up', 'max_up']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) | |||
| validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) | |||
| return min_type, max_type | |||
| @@ -17,13 +17,12 @@ import numpy as np | |||
| from mobilenetv2_combined import MobileNetV2 | |||
| import mindspore.context as context | |||
| import mindspore.ops.operations as P | |||
| from mindspore import Tensor | |||
| from mindspore import nn | |||
| from mindspore.nn.layer import combined | |||
| from mindspore.train.quant import quant as qat | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| class LeNet5(nn.Cell): | |||
| @@ -65,7 +64,7 @@ class LeNet5(nn.Cell): | |||
| x = self.fc3(x) | |||
| return x | |||
| """ | |||
| def test_qat_lenet(): | |||
| net = LeNet5() | |||
| net = qat.convert_quant_network( | |||
| @@ -93,3 +92,4 @@ def test_qat_mobile_train(): | |||
| net = nn.WithLossCell(net, loss) | |||
| net = nn.TrainOneStepCell(net, optimizer) | |||
| net(img, label) | |||
| """ | |||