|
|
|
@@ -5,7 +5,7 @@ |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
from ...functional import add_update, ones, relu, sqrt, sum, zeros |
|
|
|
from ...functional import ones, relu, sqrt, sum, zeros |
|
|
|
from ...quantization.utils import fake_quant_bias |
|
|
|
from .. import conv_bn as Float |
|
|
|
from .module import QATModule |
|
|
|
@@ -76,18 +76,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): |
|
|
|
bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1) |
|
|
|
) |
|
|
|
exponential_average_factor = 1 - self.bn.momentum |
|
|
|
add_update( |
|
|
|
self.bn.running_mean, |
|
|
|
delta=bn_mean, |
|
|
|
alpha=1 - exponential_average_factor, |
|
|
|
beta=exponential_average_factor, |
|
|
|
) |
|
|
|
add_update( |
|
|
|
self.bn.running_var, |
|
|
|
delta=bn_var, |
|
|
|
alpha=1 - exponential_average_factor, |
|
|
|
beta=exponential_average_factor, |
|
|
|
) |
|
|
|
self.bn.running_mean *= self.bn.momentum |
|
|
|
self.bn.running_mean += exponential_average_factor * bn_mean |
|
|
|
self.bn.running_var *= self.bn.momentum |
|
|
|
self.bn.running_var += exponential_average_factor * bn_var |
|
|
|
|
|
|
|
def calc_conv_bn_qat(self, inp, approx=True): |
|
|
|
if self.training and not approx: |
|
|
|
|