
move Conv2dBnAct, DenseBnAct to combined.py

yuchaojie committed 5 years ago
commit 01303fa88e (tags/v1.1.0)
5 changed files with 229 additions and 201 deletions
  1. mindspore/compression/export/quant_export.py  +4 -4
  2. mindspore/compression/quant/qat.py  +5 -5
  3. mindspore/nn/layer/__init__.py  +3 -1
  4. mindspore/nn/layer/combined.py  +215 -0
  5. mindspore/nn/layer/quant.py  +2 -191

mindspore/compression/export/quant_export.py  +4 -4

@@ -181,11 +181,11 @@ class ExportToQuantInferNetwork:
             cell_core = None
             fake_quant_act = None
             activation = None
-            if isinstance(subcell, quant.Conv2dBnAct):
+            if isinstance(subcell, nn.Conv2dBnAct):
                 cell_core = subcell.conv
                 activation = subcell.activation
                 fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
-            elif isinstance(subcell, quant.DenseBnAct):
+            elif isinstance(subcell, nn.DenseBnAct):
                 cell_core = subcell.dense
                 activation = subcell.activation
                 fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
@@ -240,9 +240,9 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
             subcell = cells[name]
             if subcell == network:
                 continue
-            if isinstance(subcell, quant.Conv2dBnAct):
+            if isinstance(subcell, nn.Conv2dBnAct):
                 network, change = self._convert_subcell(network, change, name, subcell)
-            elif isinstance(subcell, quant.DenseBnAct):
+            elif isinstance(subcell, nn.DenseBnAct):
                 network, change = self._convert_subcell(network, change, name, subcell, conv=False)
             elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
                                       quant.Conv2dQuant, quant.DenseQuant)):
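
For reference, a minimal sketch (not part of the commit) of the dispatch the exporter performs after this change: combined cells are matched through the public nn namespace, and the wrapped layers are reached through the same `conv` / `dense` attributes used above. The helper name `pick_core` is hypothetical.

    from mindspore import nn

    def pick_core(subcell):
        # Mirrors the isinstance checks in ExportToQuantInferNetwork above.
        if isinstance(subcell, nn.Conv2dBnAct):
            return subcell.conv    # the wrapped nn.Conv2d
        if isinstance(subcell, nn.DenseBnAct):
            return subcell.dense   # the wrapped nn.Dense
        return None

    cell = nn.Conv2dBnAct(3, 8, 3, has_bn=True, activation='relu')
    core = pick_core(cell)         # an nn.Conv2d instance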


mindspore/compression/quant/qat.py  +5 -5

@@ -36,7 +36,7 @@ from .quantizer import Quantizer, OptimizeOption
__all__ = ["QuantizationAwareTraining", "create_quant_config"] __all__ = ["QuantizationAwareTraining", "create_quant_config"]




def create_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
quant_delay=(0, 0), quant_delay=(0, 0),
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
per_channel=(False, False), per_channel=(False, False),
@@ -48,7 +48,7 @@ def create_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant
     Args:
         quant_observer (Observer, list or tuple): The oberser type to do quantization. The first element represent
             weights and second element represent data flow.
-            Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
+            Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver)
         quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
             eval. The first element represent weights and second element represent data flow. Default: (0, 0)
         quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
@@ -210,8 +210,8 @@ class QuantizationAwareTraining(Quantizer):
         self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
         self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
         self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
-        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
-                                    quant.DenseBnAct: self._convert_dense}
+        self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
+                                    nn.DenseBnAct: self._convert_dense}
         self.quant_config = create_quant_config(quant_delay=quant_delay,
                                                 quant_dtype=quant_dtype,
                                                 per_channel=per_channel,
@@ -257,7 +257,7 @@ class QuantizationAwareTraining(Quantizer):
             subcell = cells[name]
             if subcell == network:
                 continue
-            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
+            elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                 prefix = subcell.param_prefix
                 new_subcell = self._convert_method_map[type(subcell)](subcell)
                 new_subcell.update_parameters_name(prefix + '.')
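
A usage sketch of the updated default (it assumes create_quant_config is re-exported from mindspore.compression.quant, as its __all__ suggests): passing the observer pair explicitly is equivalent to the new default that now points at nn.FakeQuantWithMinMaxObserver.

    from mindspore import nn
    from mindspore.compression.quant import create_quant_config

    # Same as the new default in the diff, written out explicitly.
    quant_config = create_quant_config(
        quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
        quant_delay=(0, 0),
        per_channel=(False, False))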


mindspore/nn/layer/__init__.py  +3 -1

@@ -17,7 +17,7 @@ Layer.
 
 The high-level components(Cells) used to construct the neural network.
 """
-from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math
+from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math, combined
 from .activation import *
 from .normalization import *
 from .container import *
@@ -29,6 +29,7 @@ from .pooling import *
 from .image import *
 from .quant import *
 from .math import *
+from .combined import *
 
 __all__ = []
 __all__.extend(activation.__all__)
@@ -42,3 +43,4 @@ __all__.extend(pooling.__all__)
 __all__.extend(image.__all__)
 __all__.extend(quant.__all__)
 __all__.extend(math.__all__)
+__all__.extend(combined.__all__)
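
With the re-export above, both access paths resolve to the same classes. A sketch, assuming mindspore.nn re-exports this layer package (which is how nn.Conv2dBnAct is referenced elsewhere in this commit):

    from mindspore import nn
    from mindspore.nn.layer.combined import Conv2dBnAct, DenseBnAct

    assert nn.Conv2dBnAct is Conv2dBnAct
    assert nn.DenseBnAct is DenseBnAct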

mindspore/nn/layer/combined.py  +215 -0

@@ -0,0 +1,215 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Combined cells."""

from mindspore import nn
from mindspore.ops.primitive import Primitive
from mindspore._checkparam import Validator
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, LeakyReLU
from ..cell import Cell


__all__ = [
    'Conv2dBnAct',
    'DenseBnAct'
]


class Conv2dBnAct(Cell):
    r"""
    A combination of convolution, Batchnorm, and activation layer.

    This part is a more detailed overview of Conv2d op.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height
            and width of the 2D convolution window. Single int means the value is for both height and width of
            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
            width of the kernel.
        stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be
            greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
            there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than
            or equal to 1 and lower than any one of the height and width of the input. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.
        has_bn (bool): Specifies whether to use batchnorm or not. Default: False.
        momentum (float): Momentum for the moving average of batchnorm, must be in [0, 1]. Default: 0.9.
        eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater
            than 0. Default: 1e-5.
        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
        alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
        after_fake (bool): Determine whether there must be a fake quantization operation after Conv2dBnAct.
            Default: True.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu')
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> result = net(input)
        >>> result.shape
        (1, 240, 1024, 640)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bn=False,
                 momentum=0.9,
                 eps=1e-5,
                 activation=None,
                 alpha=0.2,
                 after_fake=True):
        super(Conv2dBnAct, self).__init__()

        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              pad_mode=pad_mode,
                              padding=padding,
                              dilation=dilation,
                              group=group,
                              has_bias=has_bias,
                              weight_init=weight_init,
                              bias_init=bias_init)
        self.has_bn = Validator.check_bool(has_bn, "has_bn")
        self.has_act = activation is not None
        self.after_fake = Validator.check_bool(after_fake, "after_fake")
        if has_bn:
            self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
        if activation == "leakyrelu":
            self.activation = LeakyReLU(alpha)
        else:
            self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))

    def construct(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.batchnorm(x)
        if self.has_act:
            x = self.activation(x)
        return x


class DenseBnAct(Cell):
    r"""
    A combination of Dense, Batchnorm, and the activation layer.

    This part is a more detailed overview of Dense op.

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        has_bn (bool): Specifies whether to use batchnorm or not. Default: False.
        momentum (float): Momentum for the moving average of batchnorm, must be in [0, 1]. Default: 0.9.
        eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater
            than 0. Default: 1e-5.
        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
        alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
        after_fake (bool): Determine whether there must be a fake quantization operation after DenseBnAct.
            Default: True.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = nn.DenseBnAct(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> result = net(input)
        >>> result.shape
        (2, 4)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 has_bn=False,
                 momentum=0.9,
                 eps=1e-5,
                 activation=None,
                 alpha=0.2,
                 after_fake=True):
        super(DenseBnAct, self).__init__()
        self.dense = nn.Dense(
            in_channels,
            out_channels,
            weight_init,
            bias_init,
            has_bias)
        self.has_bn = Validator.check_bool(has_bn, "has_bn")
        self.has_act = activation is not None
        self.after_fake = Validator.check_bool(after_fake, "after_fake")
        if has_bn:
            self.batchnorm = BatchNorm1d(out_channels, eps, momentum)
        if activation == "leakyrelu":
            self.activation = LeakyReLU(alpha)
        else:
            self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))

    def construct(self, x):
        x = self.dense(x)
        if self.has_bn:
            x = self.batchnorm(x)
        if self.has_act:
            x = self.activation(x)
        return x
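
A hedged usage sketch of the two activation branches handled in __init__ above: a plain string is routed through get_activation, while 'leakyrelu' maps to LeakyReLU(alpha). Shapes follow the docstring examples; not part of the commit.

    import numpy as np
    import mindspore
    from mindspore import Tensor, nn

    # Conv2d -> BatchNorm2d -> LeakyReLU(0.1)
    conv = nn.Conv2dBnAct(3, 8, 3, has_bn=True, activation='leakyrelu', alpha=0.1)
    # Dense -> ReLU
    dense = nn.DenseBnAct(8, 4, activation='relu')

    y = conv(Tensor(np.ones([1, 3, 16, 16]), mindspore.float32))   # shape (1, 8, 16, 16)
    z = dense(Tensor(np.ones([2, 8]), mindspore.float32))          # shape (2, 4)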

mindspore/nn/layer/quant.py  +2 -191

@@ -17,7 +17,6 @@
 from functools import partial
 from collections import namedtuple
 import numpy as np
-from mindspore import nn
 import mindspore.common.dtype as mstype
 from mindspore.ops.primitive import Primitive
 from mindspore.ops import operations as P
@@ -28,14 +27,12 @@ from mindspore.common.tensor import Tensor
 from mindspore._checkparam import Validator, Rel, twice
 from mindspore.compression.common import QuantDtype
 import mindspore.context as context
-from .normalization import BatchNorm2d, BatchNorm1d
-from .activation import get_activation, ReLU, LeakyReLU
+from .normalization import BatchNorm2d
+from .activation import get_activation, ReLU
 from ..cell import Cell
 from ...ops.operations import _quant_ops as Q
 
 __all__ = [
-    'Conv2dBnAct',
-    'DenseBnAct',
     'FakeQuantWithMinMaxObserver',
     'Conv2dBnFoldQuant',
     'Conv2dBnWithoutFoldQuant',
@@ -47,192 +44,6 @@ __all__ = [
]




class Conv2dBnAct(Cell):
    r"""
    A combination of convolution, Batchnorm, activation layer.

    This part is a more detailed overview of Conv2d op.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height
            and width of the 2D convolution window. Single int means the value is for both height and width of
            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
            width of the kernel.
        stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be
            greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
            there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than
            or equal to 1 and lower than any one of the height and width of the input. Default: 1.
        group (int): Splits filter into groups, `in_ channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.
        has_bn (bool): Specifies to used batchnorm or not. Default: False.
        momentum (float): Momentum for moving average for batchnorm, must be [0, 1]. Default:0.9
        eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater
            than 0. Default: 1e-5.
        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
        alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
        after_fake(bool): Determine whether there must be a fake quantization operation after Cond2dBnAct.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> result = net(input)
        >>> result.shape
        (1, 240, 1024, 640)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bn=False,
                 momentum=0.9,
                 eps=1e-5,
                 activation=None,
                 alpha=0.2,
                 after_fake=True):
        super(Conv2dBnAct, self).__init__()

        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              pad_mode=pad_mode,
                              padding=padding,
                              dilation=dilation,
                              group=group,
                              has_bias=has_bias,
                              weight_init=weight_init,
                              bias_init=bias_init)
        self.has_bn = Validator.check_bool(has_bn, "has_bn")
        self.has_act = activation is not None
        self.after_fake = Validator.check_bool(after_fake, "after_fake")
        if has_bn:
            self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
        if activation == "leakyrelu":
            self.activation = LeakyReLU(alpha)
        else:
            self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))

    def construct(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.batchnorm(x)
        if self.has_act:
            x = self.activation(x)
        return x


class DenseBnAct(Cell):
    r"""
    A combination of Dense, Batchnorm, and the activation layer.

    This part is a more detailed overview of Dense op.

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (Cell): The regularization function applied to the output of the layer, eg. 'ReLU'. Default: None.
        has_bn (bool): Specifies to use batchnorm or not. Default: False.
        momentum (float): Momentum for moving average for batchnorm, must be [0, 1]. Default:0.9
        eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater
            than 0. Default: 1e-5.
        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
            'Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'PReLU', 'LeakyReLU', 'h-Swish', and 'h-Sigmoid'. Default: None.
        alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
        after_fake(bool): Determine whether there must be a fake quantization operation after DenseBnAct.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = nn.DenseBnAct(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> result = net(input)
        >>> result.shape
        (2, 4)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 has_bn=False,
                 momentum=0.9,
                 eps=1e-5,
                 activation=None,
                 alpha=0.2,
                 after_fake=True):
        super(DenseBnAct, self).__init__()
        self.dense = nn.Dense(
            in_channels,
            out_channels,
            weight_init,
            bias_init,
            has_bias)
        self.has_bn = Validator.check_bool(has_bn, "has_bn")
        self.has_act = activation is not None
        self.after_fake = Validator.check_bool(after_fake, "after_fake")
        if has_bn:
            self.batchnorm = BatchNorm1d(out_channels, eps, momentum)
        if activation == "leakyrelu":
            self.activation = LeakyReLU(alpha)
        else:
            self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))

    def construct(self, x):
        x = self.dense(x)
        if self.has_bn:
            x = self.batchnorm(x)
        if self.has_act:
            x = self.activation(x)
        return x


class BatchNormFoldCell(Cell):
    """
    Batch normalization folded.
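
After the removal above, the combined cells are no longer exported by this module; a short sketch of where they now resolve from (module paths per this commit).

    from mindspore.nn.layer import quant, combined

    print('Conv2dBnAct' in quant.__all__)      # False after this commit
    print('Conv2dBnAct' in combined.__all__)   # True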

