Merge pull request !2194 from 王东旭/master (tag v0.5.0-beta)
@@ -214,7 +214,7 @@ class BatchNormFoldCell(Cell):
     Batch normalization folded.
 
     Args:
-        momentum (float): Momentum value should be [0, 1]. Default: 0.1.
+        momentum (float): Momentum value should be [0, 1]. Default: 0.9.
         epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
             float32 else 1e-3. Default: 1e-5.
         freeze_bn (int): Delay in steps at which computation switches from regular batch
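The new default follows the updated momentum convention, where momentum weights the running statistics rather than the current batch. A minimal NumPy sketch of that update rule, mirroring the np_result reference in the GPU test updated at the end of this diff (function name is illustrative):

    import numpy as np

    def update_running_stats(x, running_mean, running_var, momentum=0.9, epsilon=1e-5):
        # Per-channel batch statistics over N, H, W for an NCHW input.
        batch_mean = x.mean(axis=(0, 2, 3))
        batch_var = x.var(axis=(0, 2, 3))
        n = x.shape[0] * x.shape[2] * x.shape[3]
        # momentum now weights the *running* value, so momentum=0.9 keeps 90% of
        # the history per step (previously 0.1 played that role).
        new_mean = (1 - momentum) * batch_mean + momentum * running_mean
        new_var = (1 - momentum) * batch_var * n / (n - 1) + momentum * running_var
        batch_std = np.sqrt(batch_var + epsilon)
        return new_mean, new_var, batch_std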
@@ -280,6 +280,7 @@ class FakeQuantWithMinMax(Cell):
         ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization by layer or channel. Default: False.
+        channel_axis (int): Quantization by channel axis. Default: 1.
         out_channels (int): declarate the min and max channel size, Default: 1.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
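For reference, per-channel fake quantization keeps an independent (min, max) pair along channel_axis. A NumPy sketch of the simulated-quantization forward pass (the function name and the quantize/dequantize formula are illustrative, not the Cell's exact implementation):

    import numpy as np

    def fake_quant_per_channel(x, min_val, max_val, num_bits=8, channel_axis=1):
        # Reshape the per-channel min/max so they broadcast along channel_axis.
        shape = [1] * x.ndim
        shape[channel_axis] = -1
        min_val = min_val.reshape(shape)
        max_val = max_val.reshape(shape)
        quant_min, quant_max = 0, 2 ** num_bits - 1
        scale = (max_val - min_val) / (quant_max - quant_min)
        zero_point = np.round(quant_min - min_val / scale)
        # Quantize, clamp to the integer range, then dequantize.
        q = np.clip(np.round(x / scale + zero_point), quant_min, quant_max)
        return (q - zero_point) * scale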
@@ -391,17 +392,17 @@ class Conv2dBatchNormQuant(Cell):
         pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
         padding: (int): Implicit paddings on both sides of the input. Default: 0.
         eps (int): Parameters for BatchNormal. Default: 1e-5.
-        momentum (int): Parameters for BatchNormal op. Default: 0.9.
+        momentum (int): Parameters for BatchNormal op. Default: 0.997.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            convolution kernel. Default: 'None'.
+            convolution kernel. Default: 'normal'.
         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            beta vector. Default: 'None'.
+            beta vector. Default: 'zeros'.
         gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            gamma vector. Default: 'None'.
+            gamma vector. Default: 'ones'.
         mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            mean vector. Default: 'None'.
+            mean vector. Default: 'zeros'.
         var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            variance vector. Default: 'None'.
+            variance vector. Default: 'ones'.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
         fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
@@ -434,11 +435,11 @@ class Conv2dBatchNormQuant(Cell):
                  group=1,
                  eps=1e-5,
                  momentum=0.997,
-                 weight_init=None,
-                 beta_init=None,
-                 gamma_init=None,
-                 mean_init=None,
-                 var_init=None,
+                 weight_init='normal',
+                 beta_init='zeros',
+                 gamma_init='ones',
+                 mean_init='zeros',
+                 var_init='ones',
                  quant_delay=0,
                  freeze_bn=100000,
                  fake=True,
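With the defaults now expressed as initializer names, the cell can be built without passing any explicit initializer. A hedged construction sketch, assuming the cell is exported as mindspore.nn.Conv2dBatchNormQuant with the usual in_channels/out_channels/kernel_size leading parameters (only the keyword arguments shown in this hunk are confirmed by the diff):

    from mindspore import nn

    # 'normal' / 'zeros' / 'ones' are resolved through initializer() inside __init__,
    # so string names, Initializer objects, numbers and Tensors are all accepted.
    conv_bn = nn.Conv2dBatchNormQuant(3, 16, kernel_size=3,
                                      weight_init='normal',   # same as the new default
                                      momentum=0.997,
                                      quant_delay=900,
                                      freeze_bn=100000)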
@@ -477,8 +478,7 @@ class Conv2dBatchNormQuant(Cell):
                                                 pad=padding,
                                                 stride=self.stride,
                                                 dilation=self.dilation)
-            if weight_init is None:
-                weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
+            weight_shape = [1, in_channels, *self.kernel_size]
             channel_axis = 1
         else:
             self.conv = P.Conv2D(out_channel=out_channels,
@@ -488,24 +488,16 @@ class Conv2dBatchNormQuant(Cell):
                                  stride=self.stride,
                                  dilation=self.dilation,
                                  group=group)
-            if weight_init is None:
-                weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
+            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
             channel_axis = 0
-        self.weight = Parameter(weight_init, name='weight')
+        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
 
         # initialize batchnorm Parameter
-        if gamma_init is None:
-            gamma_init = initializer('ones', [out_channels])
-        self.gamma = Parameter(gamma_init, name='gamma')
-        if beta_init is None:
-            beta_init = initializer('zeros', [out_channels])
-        self.beta = Parameter(beta_init, name='beta')
-        if mean_init is None:
-            mean_init = initializer('zeros', [out_channels])
-        self.moving_mean = Parameter(mean_init, name='moving_mean', requires_grad=False)
-        if var_init is None:
-            var_init = initializer('ones', [out_channels])
-        self.moving_variance = Parameter(var_init, name='moving_variance', requires_grad=False)
+        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
+        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
+        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
+        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
+                                         requires_grad=False)
 
         # initialize fake ops
         self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
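The refactor above routes every init argument through initializer(init, shape), so a string name, a number, an Initializer instance, or a Tensor of matching shape all produce a correctly shaped Parameter, and the per-type None handling disappears. A minimal sketch of the pattern (the shape and parameter names are chosen for illustration):

    import numpy as np
    from mindspore import Tensor, Parameter
    from mindspore.common.initializer import initializer

    weight_shape = [16, 3, 3, 3]  # [out_channels, in_channels // group, *kernel_size]
    # All of these are valid inits after the refactor and come out with weight_shape:
    w_from_name = Parameter(initializer('normal', weight_shape), name='w_name')
    w_from_number = Parameter(initializer(0.0, weight_shape), name='w_number')
    w_from_tensor = Parameter(initializer(Tensor(np.ones(weight_shape, np.float32)), weight_shape),
                              name='w_tensor')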
@@ -588,8 +580,8 @@ class Conv2dQuant(Cell):
             divisible by the number of groups. Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
-            Default: None.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None.
+            Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
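Same default change for the quantized convolution without folded batchnorm. A short construction sketch, again assuming the cell is exported as mindspore.nn.Conv2dQuant with Conv2d-style leading positional parameters (only the keyword arguments below are confirmed by the diff):

    from mindspore import nn

    # weight_init='normal' and bias_init='zeros' are now the documented defaults,
    # so both keyword arguments could be omitted here.
    conv_q = nn.Conv2dQuant(3, 16, kernel_size=3, has_bias=True,
                            weight_init='normal', bias_init='zeros',
                            num_bits=8, quant_delay=900)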
@@ -619,8 +611,8 @@ class Conv2dQuant(Cell):
                  dilation=1,
                  group=1,
                  has_bias=False,
-                 weight_init=None,
-                 bias_init=None,
+                 weight_init='normal',
+                 bias_init='zeros',
                  quant_delay=0,
                  num_bits=8,
                  per_channel=False,
@@ -641,15 +633,14 @@ class Conv2dQuant(Cell):
         self.group = group
         self.quant_delay = quant_delay
 
-        if weight_init is None:
-            weight_init = initializer(
-                'normal', [out_channels, in_channels // group, *self.kernel_size])
-        self.weight = Parameter(weight_init, name='weight')
-        if bias_init is None:
-            bias_init = initializer('zeros', [out_channels])
-        if has_bias:
-            self.bias = Parameter(bias_init, name='bias')
-            self.bias_add = P.BiasAdd()
+        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
+        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
+        self.bias_add = P.BiasAdd()
+        if check_bool(has_bias):
+            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
+        else:
+            self.bias = None
 
         self.conv = P.Conv2D(out_channel=self.out_channels,
                              kernel_size=self.kernel_size,
@@ -738,8 +729,8 @@ class DenseQuant(Cell):
         self.has_bias = check_bool(has_bias)
 
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
-                    weight_init.shape[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
+                    weight_init.shape()[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
         self.weight = Parameter(initializer(
@@ -747,7 +738,7 @@ class DenseQuant(Cell):
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(
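When an explicit Tensor is passed, the new check reads its shape through the Tensor.shape() method of this MindSpore version and expects (out_channels, in_channels) for the weight and (out_channels,) for the bias. A small sketch of a valid pair (the channel sizes are illustrative):

    import numpy as np
    from mindspore import Tensor

    in_channels, out_channels = 64, 10
    # A 2-D weight of shape (out_channels, in_channels) and a 1-D bias of shape
    # (out_channels,) pass the checks above; anything else raises the shape error.
    weight_init = Tensor(np.random.normal(0, 0.01, (out_channels, in_channels)).astype(np.float32))
    bias_init = Tensor(np.zeros(out_channels, np.float32))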
@@ -65,7 +65,6 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
                    momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW",
                    kernel_name="batchnorm_fold"):
     """batchnorm_fold TBE op"""
-    momentum = 1.0 - momentum
     util.check_kernel_name(kernel_name)
     data_format = data_format.upper()
     if data_format != "NCHW":
@@ -120,13 +119,12 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
     variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
     mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
     batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
+    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
     if num == 1:
         batch_var_scaler = 0.0
     else:
         batch_var_scaler = float(num) / (num - 1)
-    batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
-    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))
+    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
 
     factor = 1.0 - momentum
     factor_reverse = momentum
@@ -134,7 +132,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
     mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
     mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
 
-    var_mul = te.lang.cce.vmuls(batch_variance, factor)
+    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
     var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
     variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
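Net effect of these three hunks: momentum is no longer flipped on entry, the folding std is taken from the biased batch variance, and only the running-variance update uses the unbiased estimate. In NumPy terms, a sketch of the math (not of the TBE schedule):

    import numpy as np

    def batchnorm_fold_stats(x_sum, x_square_sum, running_mean, running_var,
                             num, momentum=0.9, epsilon=1e-5):
        batch_mean = x_sum / num
        batch_var_biased = x_square_sum / num - batch_mean ** 2
        # Folding uses the biased (population) statistics of the current batch.
        batch_std = np.sqrt(batch_var_biased + epsilon)
        # Running statistics use the unbiased estimate and the new momentum convention.
        scaler = 0.0 if num == 1 else float(num) / (num - 1)
        batch_var_unbiased = batch_var_biased * scaler
        factor = 1.0 - momentum
        mean_updated = factor * batch_mean + momentum * running_mean
        variance_updated = factor * batch_var_unbiased + momentum * running_var
        return batch_mean, batch_std, mean_updated, variance_updated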
@@ -50,15 +50,16 @@ def _fake_quant_per_layer_tbe():
 @fusion_manager.register("fake_quant_per_layer")
-def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max,
+def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric,
                                  kernel_name="fake_quant_per_layer"):
     """FakeQuantPerLayer"""
     shape = te.lang.cce.util.shape_to_list(x.shape)
     shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
     quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
     quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
-    min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
-    max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
+    if symmetric:
+        max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
+        min_val = te.lang.cce.vmuls(max_val, -1.)
 
     # CalNudge(NudgeMinMax)
     scale = te.lang.cce.vdiv(te.lang.cce.vsub(
@@ -119,12 +120,8 @@ def fake_quant_per_layer(x, min_val, max_val, y,
     input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
 
-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
-    else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
+    quant_min = 0
+    quant_max = 2 ** num_bits - 1
     if narrow_range:
         quant_min = quant_min + 1
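So the integer range is now always [0, 2**num_bits - 1] (shifted by one for narrow_range), and symmetry is handled by mirroring the float range around zero before nudging, instead of by signing the integer range. A NumPy sketch of the new behaviour:

    import numpy as np

    def quant_range_and_minmax(min_val, max_val, num_bits=8, symmetric=False, narrow_range=False):
        quant_min, quant_max = 0, 2 ** num_bits - 1
        if narrow_range:
            quant_min += 1
        if symmetric:
            # Mirror the float range around zero; the integer range stays unsigned.
            max_val = np.maximum(-min_val, max_val)
            min_val = -max_val
        return quant_min, quant_max, min_val, max_val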
@@ -132,7 +129,7 @@ def fake_quant_per_layer(x, min_val, max_val, y,
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
     max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
     res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
-                                       quant_min, quant_max, kernel_name)
+                                       quant_min, quant_max, symmetric, kernel_name)
 
     with tvm.target.cce():
         sch = generic.auto_schedule(res)
@@ -78,7 +78,7 @@ def _fake_quant_per_layer_grad_tbe():
 @fusion_manager.register("fake_quant_per_layer_grad")
-def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
+def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, symmetric,
                                       kernel_name="fake_quant_per_layer_grad"):
     """FakeQuantPerLayerGrad"""
     shape = te.lang.cce.util.shape_to_list(x.shape)
@@ -88,6 +88,10 @@ def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quan
     quant_min = te.lang.cce.broadcast(quant_min, shape_min)
     quant_max = te.lang.cce.broadcast(quant_max, shape_min)
 
+    if symmetric:
+        max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
+        min_val = te.lang.cce.vmuls(max_val, -1.)
+
     # CalNudge(NudgeMinMax)
     scale = te.lang.cce.vdiv(te.lang.cce.vsub(
         max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
@@ -142,12 +146,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
     input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
 
-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
-    else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
+    quant_min = 0
+    quant_max = 2 ** num_bits - 1
     if narrow_range:
         quant_min = quant_min + 1
@@ -155,8 +155,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
     input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
     max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
-    res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
-                                            quant_max, kernel_name)
+    res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data,
+                                            quant_min, quant_max, symmetric, kernel_name)
 
     with tvm.target.cce():
         sch = generic.auto_schedule(res)
@@ -58,7 +58,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
                      BiasAdd, Conv2D,
                      DepthwiseConv2dNative,
                      DropoutDoMask, DropoutGrad, Dropout,
-                     DropoutGenMask, Flatten, FusedBatchNorm,
+                     DropoutGenMask, Flatten, FusedBatchNorm, BNTrainingReduce, BNTrainingUpdate,
                      Gelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss,
                      LogSoftmax,
@@ -76,7 +76,6 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
                      ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
 from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
                         CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
-from . import _quant_ops
 from ._quant_ops import *
 from .thor_ops import *
@@ -101,6 +100,9 @@ __all__ = [
     'Conv2D',
     'Flatten',
     'MaxPoolWithArgmax',
+    'FusedBatchNorm',
+    'BNTrainingReduce',
+    'BNTrainingUpdate',
     'BatchNorm',
     'MaxPool',
     'TopK',
@@ -311,5 +313,4 @@ __all__ = [
     "InTopK"
 ]
-__all__.extend(_quant_ops.__all__)
 
 __all__.sort()
@@ -36,7 +36,6 @@ __all__ = ["FakeQuantPerLayer",
            "BatchNormFold2Grad",
            "BatchNormFoldD",
            "BatchNormFoldGradD",
-           "BNTrainingReduce",
            "BatchNormFold2_D",
            "BatchNormFold2GradD",
            "BatchNormFold2GradReduce",
@@ -334,7 +333,7 @@ class BatchNormFold(PrimitiveWithInfer):
     Batch normalization folded.
 
     Args:
-        momentum (float): Momentum value should be [0, 1]. Default: 0.1.
+        momentum (float): Momentum value should be [0, 1]. Default: 0.9.
         epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
             float32 else 1e-3. Default: 1e-5.
         is_training (bool): In training mode set True, else set False. Default: True.
@@ -366,7 +365,7 @@ class BatchNormFold(PrimitiveWithInfer):
     channel_axis = 1
 
     @prim_attr_register
-    def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
+    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
         """init batch norm fold layer"""
         self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
         self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
@@ -731,32 +730,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
         return x_type
 
 
-class BNTrainingReduce(PrimitiveWithInfer):
-    """
-    reduce sum at axis [0, 2, 3].
-
-    Inputs:
-        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
-
-    Outputs:
-        - **x_sum** (Tensor) - Tensor has the same shape as x.
-        - **x_square_sum** (Tensor) - Tensor has the same shape as x.
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        """init _BNTrainingReduce layer"""
-        self.init_prim_io_names(inputs=['x'],
-                                outputs=['x_sum', 'x_square_sum'])
-
-    def infer_shape(self, x_shape):
-        return [x_shape[1]], [x_shape[1]]
-
-    def infer_dtype(self, x_type):
-        return x_type, x_type
-
-
 class BatchNormFold2_D(PrimitiveWithInfer):
     """
     Scale the bias with a correction factor to the long term statistics
@@ -585,6 +585,50 @@ class FusedBatchNorm(Primitive):
         self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
 
 
+class BNTrainingReduce(PrimitiveWithInfer):
+    """
+    reduce sum at axis [0, 2, 3].
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
+
+    Outputs:
+        - **sum** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **square_sum** (Tensor) - Tensor of shape :math:`(C,)`.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
+
+    def infer_shape(self, x_shape):
+        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
+        return ([x_shape[1]], [x_shape[1]])
+
+    def infer_dtype(self, x_type):
+        return (x_type, x_type)
+
+
+class BNTrainingUpdate(PrimitiveWithInfer):
+    """
+    primitive operator of bn_training_update's register and info descriptor
+    """
+    @prim_attr_register
+    def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
+        self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
+                                outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
+        #self.isRef = validator.check_integer('isRef', isRef, [0, 1], Rel.IN)
+        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate')
+        self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate')
+
+    def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
+        return (x, variance, variance, variance, variance)
+
+    def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
+        return (x, variance, variance, variance, variance)
+
+
 class BatchNorm(PrimitiveWithInfer):
     r"""
     Batch Normalization for input data and updated parameters.
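For orientation, the two new primitives split one BatchNorm training step in two: BNTrainingReduce produces the per-channel sum and sum of squares over the (N, H, W) axes, and BNTrainingUpdate consumes them to normalize the input and refresh the running statistics. A NumPy sketch of the intended semantics, based on my reading of the I/O names above rather than on anything confirmed by this diff (the real kernel's conventions for the save_* outputs and the variance correction may differ):

    import numpy as np

    def bn_training_reduce(x):
        # x: NCHW. Per-channel sum and sum of squares, each of shape (C,).
        return x.sum(axis=(0, 2, 3)), (x * x).sum(axis=(0, 2, 3))

    def bn_training_update(x, x_sum, x_square_sum, scale, b, mean, variance,
                           factor=0.1, epsilon=1e-5):
        num = x.shape[0] * x.shape[2] * x.shape[3]
        batch_mean = x_sum / num
        batch_var = x_square_sum / num - batch_mean ** 2

        def per_channel(v):
            return v.reshape(1, -1, 1, 1)

        y = per_channel(scale) * (x - per_channel(batch_mean)) / \
            np.sqrt(per_channel(batch_var) + epsilon) + per_channel(b)
        # Running statistics blended with 'factor' (the update momentum).
        running_mean = (1 - factor) * mean + factor * batch_mean
        running_variance = (1 - factor) * variance + factor * batch_var
        save_mean = batch_mean
        save_inv_variance = 1.0 / np.sqrt(batch_var + epsilon)
        return y, running_mean, running_variance, save_mean, save_inv_variance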
@@ -28,7 +28,7 @@ context.set_context(device_target='GPU')
 class Net(nn.Cell):
     def __init__(self):
         super(Net, self).__init__()
-        self.op = P.BatchNormFold(freeze_bn=10)
+        self.op = P.BatchNormFold(momentum=0.9, freeze_bn=10)
 
     @ms_function
     def construct(self, x, mean, variance, current_step):
@@ -40,8 +40,8 @@ def np_result(x, mean, var, momentum, epsilon):
     np_mean = x.mean(axis=(0, 2, 3))
     np_var = x.var(axis=(0, 2, 3))
     n = x.shape[0] * x.shape[2] * x.shape[3]
-    mean_update = momentum * np_mean + (1 - momentum) * mean
-    var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var
+    mean_update = (1 - momentum) * np_mean + momentum * mean
+    var_update = (1 - momentum) * np_var * n / (n - 1) + momentum * var
     np_var = np.sqrt(np_var + epsilon)
     delay_mean = mean.copy()
     delay_std = np.sqrt(var + epsilon)