@@ -35,6 +35,10 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        if self.freeze:
+            assert (
+                self._track_running_stats_saved
+            ), "track_running_stats must be initialized to True if freeze is True"
         tshape = (1, self.num_features, 1, 1)
         if self.affine:
             self.weight = Parameter(np.ones(tshape, dtype=np.float32))
@@ -84,10 +88,24 @@ class _BatchNorm(Module):
             inp = inp.reshape(new_shape)
 
-        if self.freeze and self.training and self._track_running_stats_saved:
-            scale = self.weight * (self.running_var + self.eps) ** (-0.5)
-            bias = self.bias - self.running_mean * scale
-            return inp * scale.detach() + bias.detach()
+        _weight = self.weight
+        _bias = self.bias
+
+        if self.freeze:
+            if _weight is not None:
+                _weight = _weight.detach()
+            if _bias is not None:
+                _bias = _bias.detach()
+
+            # Need to expand to elementwise operations here
+            # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
+            scale = (self.running_var + self.eps) ** (-0.5)
+            if _weight is not None:
+                scale *= _weight
+            bias = -self.running_mean * scale
+            if _bias is not None:
+                bias += _bias
+            return inp * scale + bias
 
         if self.training and self.track_running_stats:
             exponential_average_factor = self.momentum
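For reference, the new frozen path folds the normalization into a single per-channel affine transform, so no batch statistics are computed and the affine parameters receive no gradients. A minimal NumPy sketch of the arithmetic it relies on (variable names here are illustrative, not part of the module):

```python
import numpy as np

eps = 1e-5
x = np.random.rand(2, 3, 4, 4).astype("float32")                 # NCHW input
running_mean = np.random.rand(1, 3, 1, 1).astype("float32")
running_var = np.random.rand(1, 3, 1, 1).astype("float32") + 0.5
weight = np.random.rand(1, 3, 1, 1).astype("float32")            # affine scale
bias = np.random.rand(1, 3, 1, 1).astype("float32")              # affine shift

# textbook form: normalize with the running statistics, then apply the affine parameters
ref = (x - running_mean) / np.sqrt(running_var + eps) * weight + bias

# folded form used on the frozen path: one precomputed scale and one precomputed bias
scale = weight * (running_var + eps) ** (-0.5)
folded_bias = bias - running_mean * scale
out = x * scale + folded_bias

np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-6)
```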
@@ -98,8 +116,8 @@ class _BatchNorm(Module):
             inp,
             self.running_mean if self.track_running_stats else None,
             self.running_var if self.track_running_stats else None,
-            self.weight,
-            self.bias,
+            _weight,
+            _bias,
             training=self.training
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
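Passing the possibly detached `_weight`/`_bias` into `batch_norm` is what keeps the affine parameters from accumulating gradients while frozen. A small sketch of the `detach()` semantics this relies on, assuming MegEngine's `GradManager` behaves as it does in the tests further below:

```python
import numpy as np
from megengine import Parameter
from megengine.autodiff import GradManager

w = Parameter(np.ones((1,), dtype=np.float32))
gm = GradManager()
gm.attach([w])

with gm:
    attached = (w * 2.0).sum()            # gradient flows back to w through this term
    detached = (w.detach() * 2.0).sum()   # detach() cuts the graph; this term contributes nothing
    gm.backward(attached + detached)

print(w.grad)  # expected to reflect only the attached path (2.0), not 4.0
```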
@@ -121,7 +139,7 @@ class _BatchNorm(Module):
 
 class SyncBatchNorm(_BatchNorm):
     r"""
-    Applies Synchronization Batch Normalization.
+    Applies Synchronized Batch Normalization for distributed training.
     """
 
     def __init__(
@@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm):
         else:
             exponential_average_factor = 0.0  # useless
 
+        _weight = self.weight
+        _bias = self.bias
+
+        if self.freeze:
+            if _weight is not None:
+                _weight = _weight.detach()
+            if _bias is not None:
+                _bias = _bias.detach()
+
         output = sync_batch_norm(
             inp,
             self.running_mean,
             self.running_var,
-            self.weight,
-            self.bias,
-            self.training or not self.track_running_stats,
-            exponential_average_factor,
-            self.eps,
+            _weight,
+            _bias,
+            training=(self.training and not self.freeze)
+            or ((self.running_mean is None) and (self.running_var is None)),
+            momentum=exponential_average_factor,
+            eps=self.eps,
             group=self.group,
         )
@@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm):
     :param freeze: when set to True, this module does not update the
         running mean and variance, and uses the running mean and variance instead of
         the batch mean and batch variance to normalize the input. The parameter takes effect
-        only when the module is initilized with track_running_stats as True and
-        the module is in training mode.
+        only when the module is initialized with track_running_stats as True.
         Default: False
 
     Examples:
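A minimal usage sketch of the documented `freeze` behaviour (an illustration, not one of the docstring's own examples):

```python
import numpy as np
import megengine
from megengine.module import BatchNorm2d

m = BatchNorm2d(3, freeze=True)   # requires track_running_stats=True (the default)
m.train()                         # even in training mode the running statistics stay fixed

data = megengine.Tensor(np.random.random((2, 3, 4, 4)).astype("float32"))
out = m(data)                     # normalized with running_mean/running_var, not batch statistics
```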
@@ -11,15 +11,23 @@ import pytest
 
 import megengine
 import megengine.autodiff as ad
+import megengine.distributed as dist
+import megengine.functional as F
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
-from megengine.module import BatchNorm2d, Module
+from megengine.module import BatchNorm2d, Module, SyncBatchNorm
 
 
-def test_frozen_bn():
+def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
     nchannel = 3
-    m = BatchNorm2d(nchannel, freeze=True)
+    m = BNModule(nchannel, freeze=True)
+    var = 4.0
+    bias = 1.0
+    shape = (1, nchannel, 1, 1)
+    m.running_var[...] = var * F.ones(shape)
+    m.running_mean[...] = bias * F.ones(shape)
+
     saved_var = m.running_var.numpy()
     saved_mean = m.running_mean.numpy()
@@ -31,16 +39,45 @@ def test_frozen_bn():
     optim.clear_grad()
 
     data = np.random.random((6, nchannel, 2, 2)).astype("float32")
-    with gm:
-        loss = m(data).mean()
-        gm.backward(loss)
-    optim.step()
-    np.testing.assert_equal(m.running_var.numpy(), saved_var)
-    np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
-    np.testing.assert_equal(m.weight.numpy(), saved_wt)
-    np.testing.assert_equal(m.bias.numpy(), saved_bias)
-    np.testing.assert_almost_equal(loss.numpy(), data.mean(), 5)
+
+    def train_fn(d):
+        for _ in range(3):
+            with gm:
+                loss = m(d).mean()
+                gm.backward(loss)
+            optim.step()
+        return loss
+
+    if use_trace:
+        train_fn = trace(train_fn, symbolic=use_symbolic)
+
+    for _ in range(3):
+        loss = train_fn(megengine.Tensor(data))
+        np.testing.assert_equal(m.running_var.numpy(), saved_var)
+        np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+        np.testing.assert_equal(m.weight.numpy(), saved_wt)
+        np.testing.assert_equal(m.bias.numpy(), saved_bias)
+        np.testing.assert_almost_equal(
+            loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
+        )
+
+
+def test_frozen_bn():
+    run_frozen_bn(BatchNorm2d)
+    run_frozen_bn(BatchNorm2d, True, False)
+    run_frozen_bn(BatchNorm2d, True, True)
+
+
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_frozen_synced_bn():
+    @dist.launcher(n_gpus=2)
+    def worker():
+        run_frozen_bn(SyncBatchNorm)
+        run_frozen_bn(SyncBatchNorm, True, False)
+        run_frozen_bn(SyncBatchNorm, True, True)
+
+    worker()
+
+
 def test_bn_no_track_stat():
@@ -112,3 +149,11 @@ def test_trace_bn_forward_twice():
     x = np.ones((1, 1, 32, 32), dtype=np.float32)
     y = train_bn(x, net=Simple())
     np.testing.assert_equal(y.numpy(), 0)
+
+
+# https://github.com/MegEngine/MegEngine/issues/145
+def test_frozen_bn_no_affine():
+    nchannel = 3
+    m = BatchNorm2d(nchannel, freeze=True, affine=False)
+    data = megengine.Tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
+    m(data).numpy()
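The regression test above only checks that the forward pass runs with `affine=False`. A sketch of a stricter numeric check that could accompany it, computing the expected output directly from the module's running statistics (an illustrative addition, not part of the patch):

```python
def test_frozen_bn_no_affine_value():
    nchannel = 3
    m = BatchNorm2d(nchannel, freeze=True, affine=False)

    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    out = m(megengine.Tensor(data)).numpy()

    # with affine=False there is no weight/bias, so frozen BN is a pure normalization
    expected = (data - m.running_mean.numpy()) / np.sqrt(m.running_var.numpy() + m.eps)
    np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-6)
```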