diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index c24bd5b9c4..fd6b64e368 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -845,7 +845,7 @@ class Conv2dBnWithoutFoldQuant(Cell): channel_axis=channel_axis, num_channels=out_channels, quant_dtype=quant_dtype) - self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=1-momentum) + self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum) def construct(self, x): weight = self.fake_quant_weight(self.weight)