GitOrigin-RevId: edefbec7b7
tags/v1.3.0
| @@ -641,6 +641,7 @@ class DeformableConv2d(_ConvNd): | |||||
| bias: bool = True, | bias: bool = True, | ||||
| conv_mode: str = "CROSS_CORRELATION", | conv_mode: str = "CROSS_CORRELATION", | ||||
| compute_mode: str = "DEFAULT", | compute_mode: str = "DEFAULT", | ||||
| **kwargs | |||||
| ): | ): | ||||
| kernel_size = _pair_nonzero(kernel_size) | kernel_size = _pair_nonzero(kernel_size) | ||||
| stride = _pair_nonzero(stride) | stride = _pair_nonzero(stride) | ||||
| @@ -657,6 +658,7 @@ class DeformableConv2d(_ConvNd): | |||||
| dilation, | dilation, | ||||
| groups, | groups, | ||||
| bias, | bias, | ||||
| **kwargs, | |||||
| ) | ) | ||||
| def _get_fanin(self): | def _get_fanin(self): | ||||
| @@ -21,8 +21,9 @@ class DeformablePSROIPooling(Module): | |||||
| sample_per_part, | sample_per_part, | ||||
| spatial_scale, | spatial_scale, | ||||
| trans_std: float = 0.1, | trans_std: float = 0.1, | ||||
| **kwargs | |||||
| ): | ): | ||||
| super().__init__() | |||||
| super().__init__(**kwargs) | |||||
| self.no_trans = no_trans | self.no_trans = no_trans | ||||
| self.part_size = part_size | self.part_size = part_size | ||||
| self.pooled_h = pooled_h | self.pooled_h = pooled_h | ||||
| @@ -69,7 +69,17 @@ class Module(metaclass=ABCMeta): | |||||
| Base Module class. | Base Module class. | ||||
| """ | """ | ||||
| def __init__(self, name=""): | |||||
| def __init__(self, name=None): | |||||
| """ | |||||
| :param name: module's name, can be initialized by the ``kwargs`` parameter | |||||
| of child class. | |||||
| """ | |||||
| if name is not None: | |||||
| assert ( | |||||
| isinstance(name, str) and name.strip() | |||||
| ), "Module's name must be a non-empty string" | |||||
| self.name = name | self.name = name | ||||
| # runtime attributes | # runtime attributes | ||||
| @@ -109,7 +119,7 @@ class Module(metaclass=ABCMeta): | |||||
| return HookHandler(self._forward_hooks, hook) | return HookHandler(self._forward_hooks, hook) | ||||
| def __call__(self, *inputs, **kwargs): | def __call__(self, *inputs, **kwargs): | ||||
| auto_naming.push_scope(self.name if self.name else self._name) | |||||
| auto_naming.push_scope(self.name if self.name is not None else self._name) | |||||
| for hook in self._forward_pre_hooks.values(): | for hook in self._forward_pre_hooks.values(): | ||||
| modified_inputs = hook(self, inputs) | modified_inputs = hook(self, inputs) | ||||
| if modified_inputs is not None: | if modified_inputs is not None: | ||||
| @@ -28,6 +28,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule): | |||||
| float_module.in_features, | float_module.in_features, | ||||
| float_module.out_features, | float_module.out_features, | ||||
| float_module.bias is not None, | float_module.bias is not None, | ||||
| name=float_module.name, | |||||
| ) | ) | ||||
| qat_module.weight = float_module.weight | qat_module.weight = float_module.weight | ||||
| qat_module.bias = float_module.bias | qat_module.bias = float_module.bias | ||||
| @@ -27,4 +27,4 @@ class Concat(Float.Concat, QATModule): | |||||
| Return a :class:`~.QATModule` instance converted from | Return a :class:`~.QATModule` instance converted from | ||||
| a float :class:`~.Module` instance. | a float :class:`~.Module` instance. | ||||
| """ | """ | ||||
| return cls() | |||||
| return cls(name=float_module.name) | |||||
| @@ -43,6 +43,7 @@ class Conv2d(Float.Conv2d, QATModule): | |||||
| float_module.bias is not None, | float_module.bias is not None, | ||||
| float_module.conv_mode, | float_module.conv_mode, | ||||
| float_module.compute_mode, | float_module.compute_mode, | ||||
| name=float_module.name, | |||||
| ) | ) | ||||
| qat_module.weight = float_module.weight | qat_module.weight = float_module.weight | ||||
| qat_module.bias = float_module.bias | qat_module.bias = float_module.bias | ||||
| @@ -155,6 +155,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): | |||||
| float_module.conv.bias is not None, | float_module.conv.bias is not None, | ||||
| float_module.conv.conv_mode, | float_module.conv.conv_mode, | ||||
| float_module.conv.compute_mode, | float_module.conv.compute_mode, | ||||
| name=float_module.name, | |||||
| ) | ) | ||||
| qat_module.conv.weight = float_module.conv.weight | qat_module.conv.weight = float_module.conv.weight | ||||
| qat_module.conv.bias = float_module.conv.bias | qat_module.conv.bias = float_module.conv.bias | ||||
| @@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule): | |||||
| Return a :class:`~.QATModule` instance converted from | Return a :class:`~.QATModule` instance converted from | ||||
| a float :class:`~.Module` instance. | a float :class:`~.Module` instance. | ||||
| """ | """ | ||||
| return cls(float_module.method) | |||||
| return cls(float_module.method, name=float_module.name) | |||||
| @@ -36,7 +36,9 @@ class Linear(Float.Linear, QATModule): | |||||
| Return a :class:`~.QATModule` instance converted from | Return a :class:`~.QATModule` instance converted from | ||||
| a float :class:`~.Module` instance. | a float :class:`~.Module` instance. | ||||
| """ | """ | ||||
| qmod = cls(float_module.in_features, float_module.out_features) | |||||
| qmod = cls( | |||||
| float_module.in_features, float_module.out_features, name=float_module.name | |||||
| ) | |||||
| qmod.weight = float_module.weight | qmod.weight = float_module.weight | ||||
| qmod.bias = float_module.bias | qmod.bias = float_module.bias | ||||
| return qmod | return qmod | ||||
| @@ -26,8 +26,8 @@ class QATModule(Module): | |||||
| with_weight = True | with_weight = True | ||||
| with_act = True | with_act = True | ||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def __init__(self, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.weight_observer = None # type: Observer | self.weight_observer = None # type: Observer | ||||
| self.act_observer = None # type: Observer | self.act_observer = None # type: Observer | ||||
| @@ -26,7 +26,7 @@ class QuantStub(Float.QuantStub, QATModule): | |||||
| Return a :class:`~.QATModule` instance converted from | Return a :class:`~.QATModule` instance converted from | ||||
| a float :class:`~.Module` instance. | a float :class:`~.Module` instance. | ||||
| """ | """ | ||||
| return cls() | |||||
| return cls(name=float_module.name) | |||||
| class DequantStub(Float.DequantStub, QATModule): | class DequantStub(Float.DequantStub, QATModule): | ||||
| @@ -47,4 +47,4 @@ class DequantStub(Float.DequantStub, QATModule): | |||||
| Return a :class:`~.QATModule` instance converted from | Return a :class:`~.QATModule` instance converted from | ||||
| a float :class:`~.Module` instance. | a float :class:`~.Module` instance. | ||||
| """ | """ | ||||
| return cls() | |||||
| return cls(name=float_module.name) | |||||
| @@ -61,13 +61,14 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule): | |||||
| qat_module.out_features, | qat_module.out_features, | ||||
| qat_module.bias is not None, | qat_module.bias is not None, | ||||
| dtype=output_dtype, | dtype=output_dtype, | ||||
| name=qat_module.name, | |||||
| ) | ) | ||||
| weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | ||||
| weight = expand_dims(weight, [-1, -2]) | weight = expand_dims(weight, [-1, -2]) | ||||
| qbmm.weight = Parameter(weight.numpy()) | |||||
| qbmm.weight = Parameter(weight.numpy(), name=qat_module.weight.name) | |||||
| if qat_module.bias is not None: | if qat_module.bias is not None: | ||||
| bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1)) | bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1)) | ||||
| qbmm.bias = Parameter(bias.numpy()) | |||||
| qbmm.bias = Parameter(bias.numpy(), name=qat_module.bias.name) | |||||
| else: | else: | ||||
| qbmm.bias = Parameter( | qbmm.bias = Parameter( | ||||
| np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32) | np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32) | ||||
| @@ -18,8 +18,8 @@ class Concat(QuantizedModule): | |||||
| A :class:`~.QuantizedModule` to do quantized :func:`~.concat`, used for inference only. | A :class:`~.QuantizedModule` to do quantized :func:`~.concat`, used for inference only. | ||||
| """ | """ | ||||
| def __init__(self, dtype=None): | |||||
| super().__init__() | |||||
| def __init__(self, dtype=None, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.output_dtype = dtype | self.output_dtype = dtype | ||||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | def forward(self, inps: Iterable[Tensor], axis: int = 0): | ||||
| @@ -32,4 +32,4 @@ class Concat(QuantizedModule): | |||||
| Return a :class:`~.QuantizedModule` instance converted from a | Return a :class:`~.QuantizedModule` instance converted from a | ||||
| :class:`~.QATModule` instance. | :class:`~.QATModule` instance. | ||||
| """ | """ | ||||
| return cls(qat_module.get_activation_dtype()) | |||||
| return cls(qat_module.get_activation_dtype(), name=qat_module.name) | |||||
| @@ -37,6 +37,7 @@ class Conv2d(Float.Conv2d, QuantizedModule): | |||||
| conv_mode: str = "CROSS_CORRELATION", | conv_mode: str = "CROSS_CORRELATION", | ||||
| compute_mode: str = "DEFAULT", | compute_mode: str = "DEFAULT", | ||||
| dtype=None, | dtype=None, | ||||
| **kwargs | |||||
| ): | ): | ||||
| super().__init__( | super().__init__( | ||||
| in_channels, | in_channels, | ||||
| @@ -86,11 +87,12 @@ class Conv2d(Float.Conv2d, QuantizedModule): | |||||
| qat_module.dilation, | qat_module.dilation, | ||||
| qat_module.groups, | qat_module.groups, | ||||
| dtype=output_dtype, | dtype=output_dtype, | ||||
| name=qat_module.name, | |||||
| ) | ) | ||||
| weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | ||||
| qconv.weight = Parameter(weight.numpy()) | |||||
| qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name) | |||||
| if qat_module.bias is not None: | if qat_module.bias is not None: | ||||
| qconv.bias = Parameter(qat_module.bias.numpy()) | |||||
| qconv.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) | |||||
| else: | else: | ||||
| qconv.bias = Parameter( | qconv.bias = Parameter( | ||||
| np.zeros(qat_module._infer_bias_shape(), dtype=np.float32) | np.zeros(qat_module._infer_bias_shape(), dtype=np.float32) | ||||
| @@ -33,13 +33,14 @@ class _ConvBnActivation2d(Conv2d): | |||||
| qat_module.conv.dilation, | qat_module.conv.dilation, | ||||
| qat_module.conv.groups, | qat_module.conv.groups, | ||||
| dtype=output_dtype, | dtype=output_dtype, | ||||
| name=qat_module.name, | |||||
| ) | ) | ||||
| w_fold, b_fold = qat_module.fold_weight_bias( | w_fold, b_fold = qat_module.fold_weight_bias( | ||||
| qat_module.bn.running_mean, qat_module.bn.running_var | qat_module.bn.running_mean, qat_module.bn.running_var | ||||
| ) | ) | ||||
| weight = w_fold.astype(qat_module.get_weight_dtype()) | weight = w_fold.astype(qat_module.get_weight_dtype()) | ||||
| qconv.weight = Parameter(weight.numpy()) | |||||
| qconv.bias = Parameter(b_fold.numpy()) | |||||
| qconv.weight = Parameter(weight.numpy(), name=qat_module.conv.weight.name) | |||||
| qconv.bias = Parameter(b_fold.numpy(), name=qat_module.conv.bias.name) | |||||
| return qconv | return qconv | ||||
| @@ -14,8 +14,8 @@ from .module import QuantizedModule | |||||
| class Elemwise(QuantizedModule): | class Elemwise(QuantizedModule): | ||||
| r"""Quantized version of :class:`~.qat.Elemwise`.""" | r"""Quantized version of :class:`~.qat.Elemwise`.""" | ||||
| def __init__(self, method, dtype=None): | |||||
| super().__init__() | |||||
| def __init__(self, method, dtype=None, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.method = "Q" + method | self.method = "Q" + method | ||||
| self.output_dtype = dtype | self.output_dtype = dtype | ||||
| @@ -30,4 +30,6 @@ class Elemwise(QuantizedModule): | |||||
| Return a :class:`~.QuantizedModule` instance converted from a | Return a :class:`~.QuantizedModule` instance converted from a | ||||
| :class:`~.QATModule` instance. | :class:`~.QATModule` instance. | ||||
| """ | """ | ||||
| return cls(qat_module.method, qat_module.get_activation_dtype()) | |||||
| return cls( | |||||
| qat_module.method, qat_module.get_activation_dtype(), name=qat_module.name | |||||
| ) | |||||
| @@ -17,8 +17,8 @@ from .module import QuantizedModule | |||||
| class Linear(QuantizedModule): | class Linear(QuantizedModule): | ||||
| r"""Quantized version of :class:`~.qat.Linear`.""" | r"""Quantized version of :class:`~.qat.Linear`.""" | ||||
| def __init__(self, dtype: np.dtype = None): | |||||
| super().__init__() | |||||
| def __init__(self, dtype: np.dtype = None, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.weight = None | self.weight = None | ||||
| self.bias = None | self.bias = None | ||||
| self.output_dtype = dtype | self.output_dtype = dtype | ||||
| @@ -44,9 +44,9 @@ class Linear(QuantizedModule): | |||||
| :class:`~.QATModule` instance. | :class:`~.QATModule` instance. | ||||
| """ | """ | ||||
| output_dtype = qat_module.get_activation_dtype() | output_dtype = qat_module.get_activation_dtype() | ||||
| qmod = cls(dtype=output_dtype) | |||||
| qmod = cls(dtype=output_dtype, name=qat_module.name) | |||||
| weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | ||||
| qmod.weight = Parameter(weight.numpy()) | |||||
| qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name) | |||||
| if qat_module.bias is not None: | if qat_module.bias is not None: | ||||
| qmod.bias = Parameter(qat_module.bias.numpy()) | |||||
| qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) | |||||
| return qmod | return qmod | ||||
| @@ -15,8 +15,8 @@ class QuantStub(QuantizedModule): | |||||
| will convert input to quantized dtype. | will convert input to quantized dtype. | ||||
| """ | """ | ||||
| def __init__(self, dtype=None): | |||||
| super().__init__() | |||||
| def __init__(self, dtype=None, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.output_dtype = dtype | self.output_dtype = dtype | ||||
| def forward(self, inp): | def forward(self, inp): | ||||
| @@ -28,7 +28,7 @@ class QuantStub(QuantizedModule): | |||||
| Return a :class:`~.QuantizedModule` instance converted from a | Return a :class:`~.QuantizedModule` instance converted from a | ||||
| :class:`~.QATModule` instance. | :class:`~.QATModule` instance. | ||||
| """ | """ | ||||
| return cls(qat_module.get_activation_dtype()) | |||||
| return cls(qat_module.get_activation_dtype(), name=qat_module.name) | |||||
| class DequantStub(QuantizedModule): | class DequantStub(QuantizedModule): | ||||
| @@ -46,4 +46,4 @@ class DequantStub(QuantizedModule): | |||||
| Return a :class:`~.QuantizedModule` instance converted from a | Return a :class:`~.QuantizedModule` instance converted from a | ||||
| :class:`~.QATModule` instance. | :class:`~.QATModule` instance. | ||||
| """ | """ | ||||
| return cls() | |||||
| return cls(name=qat_module.name) | |||||
| @@ -17,6 +17,7 @@ import megengine.utils.comp_graph_tools as cgtools | |||||
| from megengine import Parameter, Tensor | from megengine import Parameter, Tensor | ||||
| from megengine.core.tensor import megbrain_graph as G | from megengine.core.tensor import megbrain_graph as G | ||||
| from megengine.jit.tracing import trace | from megengine.jit.tracing import trace | ||||
| from megengine.quantization.quantize import quantize, quantize_qat | |||||
| from megengine.utils.naming import auto_naming | from megengine.utils.naming import auto_naming | ||||
| @@ -29,14 +30,14 @@ def _dump_and_load(func, symbolic, keep_opr_name=True): | |||||
| func.dump( | func.dump( | ||||
| file, | file, | ||||
| optimize_for_inference=False, | optimize_for_inference=False, | ||||
| arg_names="x", | |||||
| arg_names=("x",), | |||||
| keep_opr_name=keep_opr_name, | keep_opr_name=keep_opr_name, | ||||
| keep_var_name=2, | keep_var_name=2, | ||||
| ) | ) | ||||
| file.seek(0) | file.seek(0) | ||||
| *_, outputs = G.load_graph(file) | *_, outputs = G.load_graph(file) | ||||
| op = cgtools.get_oprs_seq(outputs)[-1] | |||||
| return op | |||||
| ops = cgtools.get_oprs_seq(outputs) | |||||
| return ops | |||||
| @pytest.mark.parametrize("symbolic", [False, True]) | @pytest.mark.parametrize("symbolic", [False, True]) | ||||
| @@ -50,7 +51,7 @@ def test_auto_naming(symbolic): | |||||
| return x + x | return x + x | ||||
| m = Simple("simple") | m = Simple("simple") | ||||
| op = _dump_and_load(m, symbolic) | |||||
| op = _dump_and_load(m, symbolic)[-1] | |||||
| assert op.name == "simple.ADD" | assert op.name == "simple.ADD" | ||||
| assert op.outputs[0].name == "simple.ADD" | assert op.outputs[0].name == "simple.ADD" | ||||
| @@ -70,7 +71,7 @@ def test_user_named_tensor(symbolic): | |||||
| m = Simple("simple") | m = Simple("simple") | ||||
| op = _dump_and_load(m, symbolic) | |||||
| op = _dump_and_load(m, symbolic)[-1] | |||||
| assert op.name == "simple.ADD" | assert op.name == "simple.ADD" | ||||
| assert op.outputs[0].name == "o_x" | assert op.outputs[0].name == "o_x" | ||||
| @@ -88,7 +89,7 @@ def test_user_named_param(symbolic): | |||||
| m = Simple("simple") | m = Simple("simple") | ||||
| op = _dump_and_load(m, symbolic) | |||||
| op = _dump_and_load(m, symbolic)[-1] | |||||
| assert op.inputs[0].name == "x" | assert op.inputs[0].name == "x" | ||||
| assert op.inputs[1].name == "simple.k" | assert op.inputs[1].name == "simple.k" | ||||
| @@ -98,7 +99,7 @@ def test_without_module(symbolic): | |||||
| def f(x): | def f(x): | ||||
| return 2 * x | return 2 * x | ||||
| op = _dump_and_load(f, symbolic) | |||||
| op = _dump_and_load(f, symbolic)[-1] | |||||
| assert op.name == "MUL" | assert op.name == "MUL" | ||||
| @@ -116,10 +117,10 @@ def test_with_submodule(symbolic): | |||||
| m = Simple("simple") | m = Simple("simple") | ||||
| op = _dump_and_load(m, symbolic) | |||||
| assert op.name == "simple.linear.ADD" | |||||
| assert op.inputs[0].owner.name == "simple.linear.MatrixMul" | |||||
| assert op.outputs[0].name == "simple.linear.ADD" | |||||
| ops = _dump_and_load(m, symbolic) | |||||
| assert ops[-1].name == "simple.linear.ADD" | |||||
| assert ops[-2].name == "simple.linear.MatrixMul" | |||||
| assert ops[-1].outputs[0].name == "simple.linear.ADD" | |||||
| @pytest.mark.parametrize("symbolic", [False, True]) | @pytest.mark.parametrize("symbolic", [False, True]) | ||||
| @@ -136,10 +137,10 @@ def test_named_submodule(symbolic): | |||||
| m = Simple("simple") | m = Simple("simple") | ||||
| op = _dump_and_load(m, symbolic) | |||||
| assert op.name == "simple.x.ADD" | |||||
| assert op.inputs[0].owner.name == "simple.x.MatrixMul" | |||||
| assert op.outputs[0].name == "simple.x.ADD" | |||||
| ops = _dump_and_load(m, symbolic) | |||||
| assert ops[-1].name == "simple.x.ADD" | |||||
| assert ops[-2].name == "simple.x.MatrixMul" | |||||
| assert ops[-1].outputs[0].name == "simple.x.ADD" | |||||
| @pytest.mark.parametrize("symbolic", [False, True]) | @pytest.mark.parametrize("symbolic", [False, True]) | ||||
| @@ -156,14 +157,111 @@ def test_with_same_operators(symbolic): | |||||
| m = Simple("simple") | m = Simple("simple") | ||||
| op = _dump_and_load(m, symbolic) | |||||
| assert op.name == "simple.RELU[1]" | |||||
| assert op.inputs[0].owner.name == "simple.RELU[0]" | |||||
| ops = _dump_and_load(m, symbolic) | |||||
| assert ops[-1].name == "simple.RELU[1]" | |||||
| assert ops[-2].name == "simple.RELU[0]" | |||||
| def test_not_keep_opr_name(): | def test_not_keep_opr_name(): | ||||
| def f(x): | def f(x): | ||||
| return 2 * x | return 2 * x | ||||
| op = _dump_and_load(f, True, False) | |||||
| op = _dump_and_load(f, True, False)[-1] | |||||
| assert op.name == "MUL(x,2[2])[4]" | assert op.name == "MUL(x,2[2])[4]" | ||||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||||
| def test_quantized_module_auto_naming(symbolic): | |||||
| class Simple(M.Module): | |||||
| def __init__(self, name): | |||||
| super().__init__(name=name) | |||||
| self.quant = M.QuantStub() | |||||
| self.linear = M.Linear(3, 3, bias=True) | |||||
| self.dequant = M.DequantStub() | |||||
| def forward(self, x): | |||||
| out = self.quant(x) | |||||
| out = self.linear(out) | |||||
| out = self.dequant(out) | |||||
| return out | |||||
| m = Simple("simple") | |||||
| quantize_qat(m) | |||||
| quantize(m) | |||||
| m.eval() | |||||
| ops = _dump_and_load(m, symbolic) | |||||
| ops_name = ( | |||||
| "x", | |||||
| "simple.quant.TypeCvt", | |||||
| "simple.linear.MatrixMul", | |||||
| "simple.linear.ADD", | |||||
| "simple.linear.TypeCvt", | |||||
| "simple.dequant.TypeCvt", | |||||
| ) | |||||
| for op, name in zip(ops, ops_name): | |||||
| assert op.name == name | |||||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||||
| def test_quantized_module_user_naming(symbolic): | |||||
| class Simple(M.Module): | |||||
| def __init__(self, name): | |||||
| super().__init__(name=name) | |||||
| self.quant = M.QuantStub() | |||||
| self.linear = M.Linear(3, 3, bias=True, name="user-linear") | |||||
| self.dequant = M.DequantStub() | |||||
| def forward(self, x): | |||||
| out = self.quant(x) | |||||
| out = self.linear(out) | |||||
| out = self.dequant(out) | |||||
| return out | |||||
| m = Simple("simple") | |||||
| quantize_qat(m) | |||||
| quantize(m) | |||||
| m.eval() | |||||
| ops = _dump_and_load(m, symbolic) | |||||
| ops_name = ( | |||||
| "x", | |||||
| "simple.quant.TypeCvt", | |||||
| "simple.user-linear.MatrixMul", | |||||
| "simple.user-linear.ADD", | |||||
| "simple.user-linear.TypeCvt", | |||||
| "simple.dequant.TypeCvt", | |||||
| ) | |||||
| for op, name in zip(ops, ops_name): | |||||
| assert op.name == name | |||||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||||
| def test_quantized_module_user_naming_param(symbolic): | |||||
| class Simple(M.Module): | |||||
| def __init__(self, name): | |||||
| super().__init__(name=name) | |||||
| self.quant = M.QuantStub() | |||||
| self.linear = M.Linear(3, 3, bias=True) | |||||
| self.dequant = M.DequantStub() | |||||
| self.linear.weight.name = "user-weight" | |||||
| self.linear.bias.name = "user-bias" | |||||
| def forward(self, x): | |||||
| out = self.quant(x) | |||||
| out = self.linear(out) | |||||
| out = self.dequant(out) | |||||
| return out | |||||
| m = Simple("simple") | |||||
| quantize_qat(m) | |||||
| quantize(m) | |||||
| m.eval() | |||||
| ops = _dump_and_load(m, symbolic) | |||||
| (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"] | |||||
| for var in matrix_mul_op.inputs: | |||||
| assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight") | |||||
| # BUG bias' name does not meet expectations because of astype operator after quantization | |||||