GitOrigin-RevId: efc6377197
tags/v1.3.0
@@ -784,10 +784,10 @@ def sync_batch_norm(
     if is_distributed():
         # reduce all nodes' data to calculate mean and variance
-        reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
-        stat = concat(
-            [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
-        )
+        reduce_size = broadcast_to(
+            Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
+        )
+        stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
         stat = all_reduce_sum(stat, group)
         reduce_size = stat[:, :1].reshape(1)
         channel_x1s = stat[:, 1 : 1 + _channels]
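Note: the hunk above stops passing dtype into the Tensor constructor and instead casts the freshly built Tensor with .astype(), so the cast also takes effect when reduce_size is already a Tensor (the constructor's dtype argument is ignored in that case, which is exactly what the warning added in the next hunk reports). A minimal before/after sketch with made-up values:

    from megengine import Tensor

    _dtype = "float32"
    reduce_size = 64                            # stand-in value

    old = Tensor(reduce_size, dtype=_dtype)     # cast done by the constructor
    new = Tensor(reduce_size).astype(_dtype)    # construct first, cast explicitly
    assert old.dtype == new.dtype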
@@ -18,6 +18,7 @@ from .core._wrap import device as as_device
 from .core.ops.builtin import Copy, GetVarShape
 from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
+from .logger import get_logger
 from .utils.deprecation import deprecated
@@ -41,6 +42,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
            cn = device._cn
        if isinstance(data, _Tensor):
+            if dtype is not None:
+                get_logger().warning(
+                    "dtype does not work when creating a new Tensor with another Tensor"
+                )
            obj = _Tensor.__new__(cls, data)
        else:
            if isinstance(data, np.ndarray):
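A short usage sketch of the behaviour this hunk introduces (made-up values; assumes, as the warning text says, that the dtype argument simply has no effect when data is already a Tensor):

    import megengine as mge

    a = mge.Tensor([1, 2, 3], dtype="int32")

    # now logs: "dtype does not work when creating a new Tensor with another Tensor"
    b = mge.Tensor(a, dtype="float32")
    print(b.dtype)          # still int32

    # an explicit cast remains the way to change the dtype
    c = a.astype("float32")
    print(c.dtype)          # float32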
@@ -17,7 +17,7 @@ import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
-from megengine.module import BatchNorm2d, Module, SyncBatchNorm
+from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm


 def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
@@ -68,7 +68,7 @@ def test_frozen_bn():
     run_frozen_bn(BatchNorm2d, True, True)


-@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
 def test_frozen_synced_bn():
     @dist.launcher(n_gpus=2)
@@ -151,6 +151,45 @@ def test_trace_bn_forward_twice():
     np.testing.assert_equal(y.numpy(), 0)


+def run_syncbn(trace_mode):
+    x = F.ones([2, 16, 4, 4], dtype="float32")
+
+    net = Sequential(
+        Conv2d(16, 16, 1), SyncBatchNorm(16), Conv2d(16, 16, 1), SyncBatchNorm(16),
+    )
+    gm = ad.GradManager().attach(
+        net.parameters(), callbacks=dist.make_allreduce_cb("MEAN")
+    )
+    opt = optimizer.SGD(net.parameters(), 1e-3)
+
+    def train_func(x):
+        with gm:
+            y = net(x)
+            loss = y.mean()
+            gm.backward(loss)
+            opt.step().clear_grad()
+        return loss
+
+    if trace_mode is not None:
+        train_func = trace(train_func, symbolic=trace_mode)
+
+    for _ in range(3):
+        loss = train_func(x)
+        loss.numpy()
+
+
+@pytest.mark.require_ngpu(2)
+@pytest.mark.isolated_distributed
+@pytest.mark.parametrize("trace_mode", [None, True, False])
+def test_trace_several_syncbn(trace_mode):
+    @dist.launcher(n_gpus=2)
+    def worker():
+        run_syncbn(trace_mode)
+
+    worker()
+
+
 # https://github.com/MegEngine/MegEngine/issues/145
 def test_frozen_bn_no_affine():
     nchannel = 3
@@ -226,8 +226,14 @@ void DelayBroadcastPass::apply(OptState& opt) const {
                 if (!prev)
                     prev = rewriter.get_var(opr->input(inp_idx));
                 if (!opr->same_type<opr::Broadcast>()) {
-                    VarNodeArray new_inp = opr->input();
-                    new_inp.at(inp_idx) = prev;
+                    VarNodeArray new_inp(opr->input().size());
+                    for (size_t i = 0; i < opr->input().size(); i++) {
+                        if (i == inp_idx) {
+                            new_inp[i] = prev;
+                        } else {
+                            new_inp[i] = rewriter.get_var(opr->input(i));
+                        }
+                    }
                     opt.call_with_opr(opr, [&] {
                         // create new opr with the original opr's properties
                         auto new_opr = serialization::copy_opr_shallow(
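Note: the fix above stops copying opr->input() verbatim and instead routes every input through rewriter.get_var(), so inputs that were themselves replaced earlier in the chain no longer point at stale VarNodes. A minimal Python sketch of that pattern (hypothetical names, not MegEngine's gopt API):

    # map from original vars to their rewritten replacements, filled by earlier rewrites
    replaced = {}

    def get_var(v):
        # counterpart of rewriter.get_var: return the replacement if one exists
        return replaced.get(v, v)

    def remap_inputs(inputs, inp_idx, prev):
        # old behaviour: copy `inputs` and patch only position inp_idx
        # new behaviour: patch inp_idx with prev and look up every other input
        return [prev if i == inp_idx else get_var(v) for i, v in enumerate(inputs)]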
@@ -177,6 +177,32 @@ TEST_PASS(DelayBroadcastPass, LongChain) {
     ASSERT_EQ(bcast(bcast(relu(relu(x)), y), z), out);
 }

+TEST_PASS(DelayBroadcastPass, ElemwiseChain) {
+    auto typecvt = [](SymbolVar x) {
+        return opr::TypeCvt::make(x, dtype::Int32());
+    };
+
+    auto reduce = [](SymbolVar x) {
+        SymbolVar tshp = x.make_scalar(1);
+        opr::Reduce::Param param_default{opr::Reduce::Mode::SUM, INT_MAX,
+                                         opr::Reduce::Param::DataType::DEFAULT};
+        return opr::Reduce::make(x, param_default, tshp);
+    };
+
+    auto shp = TensorShape{2, 2};
+    auto x = mkvar("x", {1, 1});
+    auto val = x.make_scalar(3);
+    auto out = reduce(typecvt(x.broadcast(shp))) + val.broadcast(shp);
+    out = gopt::GraphOptimizer{}.
+            add_pass<gopt::DelayBroadcastPass>().
+            apply({{out}}).endpoint_vars()[0];
+
+    auto expected = (reduce(typecvt(x).broadcast(shp)) + val).broadcast(shp);
+    ASSERT_EQ(out, expected);
+}
+
 TEST_PASS(ExpandVirtualGradPass, Simple) {
     auto x = mkvar("x");
     check(x * 2,
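Note: the new ElemwiseChain test checks that the broadcast is delayed past the TypeCvt (an elementwise op), which is valid because casting before or after a broadcast yields the same values. A quick numpy check of that identity (illustration only, not part of the test suite):

    import numpy as np

    x = np.random.rand(1, 1).astype("float32")
    shp = (2, 2)

    cast_after_broadcast = np.broadcast_to(x, shp).astype("int32")
    broadcast_after_cast = np.broadcast_to(x.astype("int32"), shp)

    np.testing.assert_array_equal(cast_after_broadcast, broadcast_after_cast)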