|
|
|
@@ -19,6 +19,7 @@ from collections import namedtuple |
|
|
|
import numpy as np |
|
|
|
from mindspore import nn |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore.ops.primitive import Primitive |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.ops import functional as F |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
@@ -85,7 +86,7 @@ class Conv2dBnAct(Cell): |
|
|
|
momentum (float): Momentum for moving average.Momentum value must be [0, 1].Default:0.9 |
|
|
|
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default: |
|
|
|
1e-5. |
|
|
|
activation (Cell): Specifies activation type. The optional values are as following: |
|
|
|
activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following: |
|
|
|
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', |
|
|
|
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. |
|
|
|
alpha (float): Slope of the activation function at x < 0. Default: 0.2. |
|
|
|
@@ -143,7 +144,9 @@ class Conv2dBnAct(Cell): |
|
|
|
if activation == "leakyrelu": |
|
|
|
self.activation = LeakyReLU(alpha) |
|
|
|
else: |
|
|
|
self.activation = get_activation(activation) |
|
|
|
self.activation = get_activation(activation) if isinstance(activation, str) else activation |
|
|
|
if activation is not None and not isinstance(self.activation, (Cell, Primitive)): |
|
|
|
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.conv(x) |
|
|
|
@@ -170,7 +173,7 @@ class DenseBnAct(Cell): |
|
|
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. |
|
|
|
activation (Cell): The regularization function applied to the output of the layer, eg. 'ReLU'. Default: None. |
|
|
|
has_bn (bool): Specifies to use batchnorm or not. Default: False. |
|
|
|
activation (string): Specifies activation type. The optional values are as following: |
|
|
|
activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following: |
|
|
|
'Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', |
|
|
|
'PReLU', 'LeakyReLU', 'h-Swish', and 'h-Sigmoid'. Default: None. |
|
|
|
after_fake(bool): Determin whether there must be a fake quantization operation after DenseBnAct. |
|
|
|
@@ -208,7 +211,9 @@ class DenseBnAct(Cell): |
|
|
|
self.after_fake = after_fake |
|
|
|
if has_bn: |
|
|
|
self.batchnorm = BatchNorm1d(out_channels) |
|
|
|
self.activation = get_activation(activation) |
|
|
|
self.activation = get_activation(activation) if isinstance(activation, str) else activation |
|
|
|
if activation is not None and not isinstance(self.activation, (Cell, Primitive)): |
|
|
|
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.dense(x) |
|
|
|
@@ -930,7 +935,8 @@ class DenseQuant(Cell): |
|
|
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is |
|
|
|
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. |
|
|
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. |
|
|
|
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None. |
|
|
|
activation (Union[str, Cell, Primitive]): The regularization function applied to the output of the layer, |
|
|
|
eg. 'relu'. Default: None. |
|
|
|
quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. |
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
@@ -979,7 +985,9 @@ class DenseQuant(Cell): |
|
|
|
self.matmul = P.MatMul(transpose_b=True) |
|
|
|
self.bias_add = P.BiasAdd() |
|
|
|
|
|
|
|
self.activation = get_activation(activation) |
|
|
|
self.activation = get_activation(activation) if isinstance(activation, str) else activation |
|
|
|
if activation is not None and not isinstance(self.activation, (Cell, Primitive)): |
|
|
|
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) |
|
|
|
self.activation_flag = self.activation is not None |
|
|
|
self.fake_quant_weight = quant_config.weight(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
|