|
|
|
@@ -141,8 +141,7 @@ class FakeQuantWithMinMax(Cell): |
|
|
|
out_channels=1, |
|
|
|
quant_delay=0, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
training=True): |
|
|
|
narrow_range=False): |
|
|
|
"""init FakeQuantWithMinMax layer""" |
|
|
|
super(FakeQuantWithMinMax, self).__init__() |
|
|
|
self.min_init = min_init |
|
|
|
@@ -156,7 +155,6 @@ class FakeQuantWithMinMax(Cell): |
|
|
|
self.quant_delay = quant_delay |
|
|
|
self.symmetric = symmetric |
|
|
|
self.narrow_range = narrow_range |
|
|
|
self.training = training |
|
|
|
self.is_ascend = context.get_context('device_target') == "Ascend" |
|
|
|
|
|
|
|
# init tensor min and max for fake quant op |
|
|
|
@@ -190,7 +188,7 @@ class FakeQuantWithMinMax(Cell): |
|
|
|
symmetric=self.symmetric, |
|
|
|
narrow_range=self.narrow_range, |
|
|
|
training=self.training) |
|
|
|
if self.ema: |
|
|
|
if self.training: |
|
|
|
self.ema_update = ema_fun(num_bits=self.num_bits, |
|
|
|
ema=self.ema, |
|
|
|
ema_decay=self.ema_decay, |
|
|
|
@@ -206,7 +204,7 @@ class FakeQuantWithMinMax(Cell): |
|
|
|
return s |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
if self.ema and self.is_ascend: |
|
|
|
if self.is_ascend and self.training: |
|
|
|
min_up, max_up = self.ema_update(x, self.minq, self.maxq) |
|
|
|
out = self.fake_quant(x, min_up, max_up) |
|
|
|
P.Assign()(self.minq, min_up) |
|
|
|
|