| @@ -99,21 +99,9 @@ class MinMaxObserver(Observer): | |||||
| def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): | def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): | ||||
| super().__init__(dtype) | super().__init__(dtype) | ||||
| self.mode = mode | self.mode = mode | ||||
| self.min_val = Buffer(0.0, dtype=np.float32) | |||||
| self.max_val = Buffer(0.0, dtype=np.float32) | |||||
| self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | |||||
| self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | |||||
| self.scale_limit = eps | self.scale_limit = eps | ||||
| # flag is used by cond_take, first time will be first flag, and after will be set as not_flag | |||||
| self.first_flag = Buffer(np.array([1, 0], dtype=np.int32)) | |||||
| self.not_flag = Buffer(np.array([0, 1], dtype=np.int32)) | |||||
| def set_min_max(self, tmp_min, tmp_max): | |||||
| # FIXME: cond_take will destory shape, use reshape to reset shape | |||||
| tmp_min = tmp_min.reshape(1) | |||||
| tmp_max = tmp_max.reshape(1) | |||||
| F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) | |||||
| F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) | |||||
| F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0) | |||||
| def _calculate_qparams(self, inp_min_val, inp_max_val): | def _calculate_qparams(self, inp_min_val, inp_max_val): | ||||
| min_val = F.minimum(0.0, inp_min_val) | min_val = F.minimum(0.0, inp_min_val) | ||||
| @@ -144,13 +132,20 @@ class MinMaxObserver(Observer): | |||||
| # stop gradient | # stop gradient | ||||
| x = F.zero_grad(x_orig) | x = F.zero_grad(x_orig) | ||||
| # find max and min | # find max and min | ||||
| tmp_min, _ = F.cond_take( | |||||
| self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())]) | |||||
| F.add_update( | |||||
| self.min_val, | |||||
| F.minimum(self.min_val, x.min()), | |||||
| alpha=0.0, | |||||
| beta=1.0, | |||||
| bias=0.0, | |||||
| ) | ) | ||||
| tmp_max, _ = F.cond_take( | |||||
| self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())]) | |||||
| F.add_update( | |||||
| self.max_val, | |||||
| F.maximum(self.max_val, x.max()), | |||||
| alpha=0.0, | |||||
| beta=1.0, | |||||
| bias=0.0, | |||||
| ) | ) | ||||
| self.set_min_max(tmp_min, tmp_max) | |||||
| return x_orig | return x_orig | ||||
| @@ -160,6 +155,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||||
| ): | ): | ||||
| super().__init__(mode, eps, dtype) | super().__init__(mode, eps, dtype) | ||||
| self.momentum = Buffer(momentum) | self.momentum = Buffer(momentum) | ||||
| self.runtime_momentum = Buffer(0.0) | |||||
| def set_momentum(self, momentum): | def set_momentum(self, momentum): | ||||
| self.momentum.set_value(momentum) | self.momentum.set_value(momentum) | ||||
| @@ -169,25 +165,19 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||||
| # stop gradient | # stop gradient | ||||
| x = F.zero_grad(x_orig) | x = F.zero_grad(x_orig) | ||||
| # Exponential Moving Average | # Exponential Moving Average | ||||
| tmp_min, _ = F.cond_take( | |||||
| self.first_flag, | |||||
| F.concat( | |||||
| [ | |||||
| x.min(), | |||||
| self.momentum * self.min_val + (1 - self.momentum) * x.min(), | |||||
| ] | |||||
| ), | |||||
| tmp_min = ( | |||||
| self.min_val * self.runtime_momentum | |||||
| + (1 - self.runtime_momentum) * x.min() | |||||
| ) | |||||
| tmp_max = ( | |||||
| self.max_val * self.runtime_momentum | |||||
| + (1 - self.runtime_momentum) * x.max() | |||||
| ) | ) | ||||
| tmp_max, _ = F.cond_take( | |||||
| self.first_flag, | |||||
| F.concat( | |||||
| [ | |||||
| x.max(), | |||||
| self.momentum * self.max_val + (1 - self.momentum) * x.max(), | |||||
| ] | |||||
| ), | |||||
| F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) | |||||
| F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) | |||||
| F.add_update( | |||||
| self.runtime_momentum, self.momentum, alpha=0.0, beta=1.0, bias=0.0 | |||||
| ) | ) | ||||
| self.set_min_max(tmp_min, tmp_max) | |||||
| return x_orig | return x_orig | ||||