@@ -35,6 +35,10 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        if self.freeze:
+            assert (
+                self._track_running_stats_saved
+            ), "track_running_stats must be initialized to True if freeze is True"
         tshape = (1, self.num_features, 1, 1)
         if self.affine:
             self.weight = Parameter(np.ones(tshape, dtype=np.float32))
@@ -84,10 +88,24 @@ class _BatchNorm(Module):
             inp = inp.reshape(new_shape)
 
-        if self.freeze and self.training and self._track_running_stats_saved:
-            scale = self.weight * (self.running_var + self.eps) ** (-0.5)
-            bias = self.bias - self.running_mean * scale
-            return inp * scale.detach() + bias.detach()
+        _weight = self.weight
+        _bias = self.bias
+
+        if self.freeze:
+            if _weight is not None:
+                _weight = _weight.detach()
+            if _bias is not None:
+                _bias = _bias.detach()
+
+            # Need to expand to elementwise operations here
+            # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
+            scale = (self.running_var + self.eps) ** (-0.5)
+            if _weight is not None:
+                scale *= _weight
+            bias = -self.running_mean * scale
+            if _bias is not None:
+                bias += _bias
+            return inp * scale + bias
 
         if self.training and self.track_running_stats:
             exponential_average_factor = self.momentum
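As a quick cross-check of the frozen fast path above: it is nothing more than a per-channel affine transform built from the stored statistics. Below is a minimal NumPy sketch of the same arithmetic; the helper name and signature are illustrative only and not part of the patch.

import numpy as np

def frozen_bn_reference(inp, running_mean, running_var, eps=1e-5, weight=None, bias=None):
    # Frozen BN: normalize with the stored running statistics only;
    # nothing is computed from the batch and nothing is updated.
    scale = (running_var + eps) ** (-0.5)   # 1 / sqrt(var + eps), shape (1, C, 1, 1)
    if weight is not None:                  # present only when affine=True
        scale = scale * weight
    shift = -running_mean * scale
    if bias is not None:
        shift = shift + bias
    return inp * scale + shift              # broadcasts over (N, C, H, W)

Guarding the weight and bias applications behind None checks, instead of folding them into one expression, is what lets the frozen path run with affine=False (see the regression test for issue #145 at the end of this diff).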
@@ -98,8 +116,8 @@ class _BatchNorm(Module):
             inp,
             self.running_mean if self.track_running_stats else None,
             self.running_var if self.track_running_stats else None,
-            self.weight,
-            self.bias,
+            _weight,
+            _bias,
             training=self.training
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
@@ -121,7 +139,7 @@ class _BatchNorm(Module):
 class SyncBatchNorm(_BatchNorm):
     r"""
-    Applies Synchronization Batch Normalization.
+    Applies Synchronized Batch Normalization for distributed training.
     """
 
     def __init__(
@@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm):
         else:
             exponential_average_factor = 0.0  # useless
 
+        _weight = self.weight
+        _bias = self.bias
+
+        if self.freeze:
+            if _weight is not None:
+                _weight = _weight.detach()
+            if _bias is not None:
+                _bias = _bias.detach()
+
         output = sync_batch_norm(
             inp,
             self.running_mean,
             self.running_var,
-            self.weight,
-            self.bias,
-            self.training or not self.track_running_stats,
-            exponential_average_factor,
-            self.eps,
+            _weight,
+            _bias,
+            training=(self.training and not self.freeze)
+            or ((self.running_mean is None) and (self.running_var is None)),
+            momentum=exponential_average_factor,
+            eps=self.eps,
             group=self.group,
         )
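One detail worth spelling out in the SyncBatchNorm change above: the training argument passed to sync_batch_norm is now a compound predicate, so a frozen module keeps normalizing with its running statistics even while in train mode. Read in isolation, the predicate is equivalent to the hypothetical standalone helper below (illustration only, not part of the patch).

def uses_batch_statistics(training, freeze, running_mean, running_var):
    # Batch statistics are used only when genuinely training and not frozen,
    # or when no running statistics exist to fall back on.
    return (training and not freeze) or (
        running_mean is None and running_var is None
    )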
@@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm):
     :param freeze: when set to True, this module does not update the
         running mean and variance, and uses the running mean and variance instead of
         the batch mean and batch variance to normalize the input. The parameter takes effect
-        only when the module is initilized with track_running_stats as True and
-        the module is in training mode.
+        only when the module is initialized with track_running_stats as True.
         Default: False
 
     Examples:
@@ -11,15 +11,23 @@ import pytest
 import megengine
 import megengine.autodiff as ad
+import megengine.distributed as dist
+import megengine.functional as F
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
-from megengine.module import BatchNorm2d, Module
+from megengine.module import BatchNorm2d, Module, SyncBatchNorm
 
 
-def test_frozen_bn():
+def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
     nchannel = 3
-    m = BatchNorm2d(nchannel, freeze=True)
+    m = BNModule(nchannel, freeze=True)
+    var = 4.0
+    bias = 1.0
+    shape = (1, nchannel, 1, 1)
+    m.running_var[...] = var * F.ones(shape)
+    m.running_mean[...] = bias * F.ones(shape)
 
     saved_var = m.running_var.numpy()
     saved_mean = m.running_mean.numpy()
@@ -31,16 +39,45 @@ def test_frozen_bn():
     optim.clear_grad()
 
     data = np.random.random((6, nchannel, 2, 2)).astype("float32")
-    with gm:
-        loss = m(data).mean()
-        gm.backward(loss)
-    optim.step()
-
-    np.testing.assert_equal(m.running_var.numpy(), saved_var)
-    np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
-    np.testing.assert_equal(m.weight.numpy(), saved_wt)
-    np.testing.assert_equal(m.bias.numpy(), saved_bias)
-    np.testing.assert_almost_equal(loss.numpy(), data.mean(), 5)
+
+    def train_fn(d):
+        for _ in range(3):
+            with gm:
+                loss = m(d).mean()
+                gm.backward(loss)
+            optim.step()
+        return loss
+
+    if use_trace:
+        train_fn = trace(train_fn, symbolic=use_symbolic)
+
+    for _ in range(3):
+        loss = train_fn(megengine.Tensor(data))
+        np.testing.assert_equal(m.running_var.numpy(), saved_var)
+        np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+        np.testing.assert_equal(m.weight.numpy(), saved_wt)
+        np.testing.assert_equal(m.bias.numpy(), saved_bias)
+        np.testing.assert_almost_equal(
+            loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
+        )
+
+
+def test_frozen_bn():
+    run_frozen_bn(BatchNorm2d)
+    run_frozen_bn(BatchNorm2d, True, False)
+    run_frozen_bn(BatchNorm2d, True, True)
+
+
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_frozen_synced_bn():
+    @dist.launcher(n_gpus=2)
+    def worker():
+        run_frozen_bn(SyncBatchNorm)
+        run_frozen_bn(SyncBatchNorm, True, False)
+        run_frozen_bn(SyncBatchNorm, True, True)
+
+    worker()
+
+
 def test_bn_no_track_stat():
@@ -112,3 +149,11 @@ def test_trace_bn_forward_twice():
     x = np.ones((1, 1, 32, 32), dtype=np.float32)
     y = train_bn(x, net=Simple())
     np.testing.assert_equal(y.numpy(), 0)
+
+
+# https://github.com/MegEngine/MegEngine/issues/145
+def test_frozen_bn_no_affine():
+    nchannel = 3
+    m = BatchNorm2d(nchannel, freeze=True, affine=False)
+    data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
+    m(data).numpy()
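Finally, a short usage sketch (assuming a MegEngine build with this patch applied) of the behaviour the tests above exercise: with freeze=True the running statistics and affine parameters stay fixed, and the output is computed purely from the stored statistics, in both train and eval mode.

import numpy as np
import megengine
import megengine.functional as F
from megengine.module import BatchNorm2d

# Illustrative only: frozen BN normalizes with the stored running statistics.
m = BatchNorm2d(3, freeze=True)
m.running_mean[...] = 1.0 * F.ones((1, 3, 1, 1))
m.running_var[...] = 4.0 * F.ones((1, 3, 1, 1))

x = megengine.Tensor(np.random.random((2, 3, 4, 4)).astype("float32"))
y = m(x)
# weight is initialized to ones and bias to zeros, so the frozen output is
# simply (x - running_mean) / sqrt(running_var + eps).
np.testing.assert_allclose(
    y.numpy(), (x.numpy() - 1.0) / np.sqrt(4.0 + m.eps), rtol=1e-4, atol=1e-5
)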