|
|
|
@@ -29,6 +29,7 @@ from ...common import dtype as mstype |
|
|
|
from ...common.api import _executor |
|
|
|
from ...nn.layer import quant |
|
|
|
from ...ops import functional as F |
|
|
|
from ...ops import operations as P |
|
|
|
from ...ops.operations import _inner_ops as inner |
|
|
|
from ...train import serialization |
|
|
|
from . import quant_utils |
|
|
|
@@ -366,8 +367,6 @@ class ExportToQuantInferNetwork: |
|
|
|
sqrt_mode = True |
|
|
|
dequant_op = inner.Dequant(sqrt_mode) |
|
|
|
|
|
|
|
# get op |
|
|
|
op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv |
|
|
|
if isinstance(activation, _AddFakeQuantAfterSubCell): |
|
|
|
activation = activation.subcell |
|
|
|
elif hasattr(activation, "get_origin"): |
|
|
|
@@ -383,10 +382,17 @@ class ExportToQuantInferNetwork: |
|
|
|
weight, bias = quant_utils.fold_batchnorm(weight, cell_core) |
|
|
|
|
|
|
|
# apply the quant |
|
|
|
weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type) |
|
|
|
weight = quant_utils.weight2int(weight, scale_w, zp_w) |
|
|
|
if bias is not None: |
|
|
|
bias = Tensor(scale_a_in * scale_w * bias, mstype.int32) |
|
|
|
scale_deq = Tensor(scale_deq, mstype.float16) |
|
|
|
# get op |
|
|
|
if isinstance(cell_core, quant.DenseQuant): |
|
|
|
op_core = P.MatMul() |
|
|
|
weight = np.transpose(weight) |
|
|
|
else: |
|
|
|
op_core = cell_core.conv |
|
|
|
weight = Tensor(weight, self.data_type) |
|
|
|
block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) |
|
|
|
return block |
|
|
|
|
|
|
|
|