GitOrigin-RevId: c915c843b8
tags/v1.1.0
@@ -15,6 +15,7 @@ from .qconfig import (
    ema_fakequant_qconfig,
    ema_lowbit_fakequant_qconfig,
    min_max_fakequant_qconfig,
    sync_ema_fakequant_qconfig,
    tqt_quant_qconfig,
)
from .utils import QuantMode

@@ -12,6 +12,8 @@ import numpy as np
from .. import functional as F
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..module import Module
from ..tensor import Tensor
from .utils import QuantMode, Round, get_qparam_dict

@@ -123,6 +125,21 @@ class MinMaxObserver(Observer):
        return x_orig


class SyncMinMaxObserver(MinMaxObserver):
    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val._reset(F.minimum(self.min_val, min_x))
            self.max_val._reset(F.maximum(self.max_val, max_x))
        return x_orig
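
A minimal sketch (plain numpy, not MegEngine code) of what the reduction above computes: each rank takes the min/max of its local shard, and all_reduce_min/all_reduce_max over WORLD leave every rank with the global extremes. The shard values below are illustrative only.

import numpy as np

shard0 = np.array([0.1, 0.9], dtype="float32")   # batch slice seen by rank 0
shard1 = np.array([-0.5, 0.4], dtype="float32")  # batch slice seen by rank 1

# equivalent of all_reduce_min / all_reduce_max across the two ranks:
# both ranks end up tracking the same global statistics
global_min = min(shard0.min(), shard1.min())  # -0.5
global_max = max(shard0.max(), shard1.max())  # 0.9
print(global_min, global_max)
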
class ExponentialMovingAverageObserver(MinMaxObserver):
    def __init__(
        self,

@@ -157,6 +174,28 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
        return x_orig


class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val._reset(
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * min_x
            )
            self.max_val._reset(
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * max_x
            )
            self.runtime_momentum = self.momentum
        return x_orig
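
For reference, the running statistics are updated with the usual exponential moving average after the cross-rank reduction. A plain-Python sketch with illustrative values (the default momentum is defined in ExponentialMovingAverageObserver.__init__ and is not shown in this diff):

momentum = 0.9                 # illustrative value only
min_val, max_val = 0.0, 1.0    # running statistics from previous batches
min_x, max_x = -0.2, 1.5       # globally reduced extremes of the current batch

min_val = momentum * min_val + (1 - momentum) * min_x   # -> -0.02
max_val = momentum * max_val + (1 - momentum) * max_x   # -> 1.05
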
class HistogramObserver(MinMaxObserver):
    def __init__(
        self,

@@ -13,6 +13,8 @@ from .observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
    MinMaxObserver,
    SyncExponentialMovingAverageObserver,
    SyncMinMaxObserver,
)

@@ -92,6 +94,15 @@ ema_fakequant_qconfig = QConfig(
    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)

sync_ema_fakequant_qconfig = QConfig(
    weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
    act_observer=partial(
        SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
    ),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)
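
A hypothetical usage sketch (not part of this diff): the new qconfig is applied like the existing ones, e.g. via quantize_qat on a float model before distributed QAT. The tiny model below is a placeholder, and the process group is assumed to be initialized on every worker the same way the test at the end does.

import megengine.module as M
from megengine.quantization import quantize_qat
from megengine.quantization.qconfig import sync_ema_fakequant_qconfig


class Net(M.Module):  # placeholder float model
    def __init__(self):
        super().__init__()
        self.conv = M.ConvBn2d(3, 8, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


net = Net()
# with dist.init_process_group(...) already called on each worker, the Sync*
# observers in this qconfig keep min/max statistics identical across ranks
qat_net = quantize_qat(net, qconfig=sync_ema_fakequant_qconfig)
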
ema_lowbit_fakequant_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
    act_observer=partial(

@@ -143,7 +143,6 @@ def test_batchnorm():
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.isolated_distributed
def test_syncbn1d():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)

@@ -234,7 +233,6 @@ def test_batchnorm2d():
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.isolated_distributed
def test_syncbn2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)

@@ -305,7 +303,6 @@ def test_batchnorm_no_stats():
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.isolated_distributed
def test_syncbn_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)

@@ -354,7 +351,6 @@ def test_batchnorm2d_no_stats():
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.isolated_distributed
def test_syncbn2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)

@@ -0,0 +1,52 @@
import multiprocessing as mp
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
import megengine.quantization.observer as ob
from megengine.distributed.helper import get_device_count_by_fork


def test_min_max_observer():
    x = np.random.rand(3, 3, 3, 3).astype("float32")
    np_min, np_max = x.min(), x.max()
    x = mge.tensor(x)
    m = ob.MinMaxObserver()
    m(x)
    assert m.min_val == np_min and m.max_val == np_max

@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_sync_min_max_observer():
    x = np.random.rand(6, 3, 3, 3).astype("float32")
    np_min, np_max = x.min(), x.max()
    world_size = 2
    port = dist.get_free_ports(1)[0]
    server = dist.Server(port)

    def worker(rank, slc):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        m = ob.SyncMinMaxObserver()
        y = mge.tensor(x[slc])
        m(y)
        assert m.min_val == np_min and m.max_val == np_max

    procs = []
    for rank in range(world_size):
        slc = slice(rank * 3, (rank + 1) * 3)
        p = mp.Process(target=worker, args=(rank, slc,), daemon=True)
        p.start()
        procs.append(p)
    for p in procs:
        p.join(20)
        assert p.exitcode == 0