@@ -222,45 +222,40 @@ def _normalize_axis(
    raise


_opr_map = {
    ("-", 1): builtin.Elemwise(mode="negate"),
    ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
    ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
}

for name, mode in [
    ("+", "add"),
    ("-", "sub"),
    ("*", "mul"),
    ("/", "true_div"),
    ("//", "floor_div"),
    ("**", "pow"),
    ("max", "max"),
    ("additive", "add"),
]:
    _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
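
# Reading aid (not part of the diff): string operators are resolved by
# (name, arity), so the as_op helper introduced below maps, for example,
#     as_op("+", 2)     -> builtin.Elemwise(mode="add")
#     as_op("-", 1)     -> builtin.Elemwise(mode="negate")
#     as_op("fma3", 3)  -> builtin.Elemwise(mode="FUSE_MUL_ADD3")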


def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
    if device.physical_name.startswith("cpu"):
        gopt_level = None  # disable jit and compile

    binary_ops = {
        "+": lambda: builtin.Elemwise(mode="add"),
        "-": lambda: builtin.Elemwise(mode="sub"),
        "*": lambda: builtin.Elemwise(mode="mul"),
        "/": lambda: builtin.Elemwise(mode="true_div"),
        "//": lambda: builtin.Elemwise(mode="floor_div"),
        "**": lambda: builtin.Elemwise(mode="pow"),
        "√": lambda: builtin.Elemwise(mode="expm1"),
        "max": lambda: builtin.Elemwise(mode="max"),
        "additive": lambda: builtin.Elemwise(mode="add"),
    }

    unary_ops = {
        "-": lambda: builtin.Elemwise(mode="negate"),
    }

    ternary_ops = {
        "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"),
    }

    quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")}
    def as_op(op, nargs):
        if isinstance(op, str):
            assert (op, nargs) in _opr_map, "unknown operator"
            op = _opr_map[(op, nargs)]
        return op

    def decorator(func):
        builder = _SubgraphBuilder(name)

        def apply_expr(op, *args, nr_out=None):
            if isinstance(op, str):
                if len(args) == 2:
                    op = binary_ops[op]()
                elif len(args) == 1:
                    op = unary_ops[op]()
                elif len(args) == 3:
                    op = ternary_ops[op]()
                elif len(args) == 4:
                    op = quaternary_ops[op]()
            op = as_op(op, len(args))
            results = builder.apply(op, args, 1 if nr_out is None else nr_out)
            if nr_out is None:
                assert len(results) == 1

@@ -282,3 +277,40 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
        return lambda: builder.compile(gopt_level)

    return decorator


def interpret_subgraph(func, dtype, device):
    # Evaluate the subgraph eagerly, applying each op directly instead of
    # building and compiling a graph.
    def as_op(op, nargs):
        if isinstance(op, str) and (op, nargs) in _opr_map:
            op = _opr_map[(op, nargs)]
        return op

    def decorated_func(*args):
        def apply_expr(op, *args, nr_out=None):
            op = as_op(op, len(args))
            results = apply(op, *args)
            if nr_out is None:
                assert len(results) == 1
                return results[0]
            else:
                assert len(results) == nr_out
                return results

        def apply_const(value, dtype=dtype, device=device):
            return Const(value, dtype=dtype, device=device)()[0]

        outputs, outputs_has_grad = func(args, apply_expr, apply_const)
        return outputs

    return decorated_func


def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
    def decorator(func):
        if not interpret:
            # Build and compile the subgraph once; calling the result applies
            # the compiled op to the given tensors.
            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
            return lambda *args: apply(op(), *args)
        else:
            return interpret_subgraph(func, dtype, device)

    return decorator
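
For orientation, a minimal usage sketch (modeled on the test below, not part of the diff; the names "AddHalf", add_half, x and the pre-built CompNode in device are hypothetical): the decorated function receives the subgraph input vars, a callable f that applies an op (or a string operator from _opr_map) to them, and a callable c that creates constants; it returns its outputs together with per-output grad flags, and the decorated callable returns a tuple of tensors.

    @subgraph_fn("AddHalf", dtype="float32", device=device, nr_inputs=1)
    def add_half(inputs, f, c):
        (inp,) = inputs
        out = f("+", inp, c(0.5))  # elementwise add with a constant
        return (out,), (True,)     # outputs and their requires-grad flags

    (y,) = add_half(x)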

@@ -0,0 +1,108 @@
import functools

import numpy as np
import pytest

import megengine
from megengine.autodiff.grad_manager import GradManager
from megengine.core.ops.builtin import GetVarShape, Reduce, TypeCvt
from megengine.core.tensor.utils import subgraph_fn
from megengine.device import CompNode, get_default_device
from megengine.jit import trace

_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)


@functools.lru_cache(maxsize=None)
def _get_batch_norm_fn(dtype, device, channels, ndim, interpret, gopt_level):
    @subgraph_fn(
        "BatchNormNd",
        dtype=dtype,
        device=device,
        nr_inputs=4,
        interpret=interpret,
        gopt_level=gopt_level,
    )
    def batch_norm_nd(inputs, f, c):
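        # Reading aid (not part of the diff): with n = number of elements reduced
        # per channel, mean = sum(x) / n and var = sum(x^2) / n - mean^2; the
        # output is (x - mean) * weight / sqrt(var + eps) + bias, expressed below
        # as two nested fused multiply-adds (fma3).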
        input, eps, weight, bias = inputs[0:4]
        reduce_shape = c(
            (1, channels) + (1,) * (ndim - 2), dtype="int32", device=device
        )
        input_shape = f(GetVarShape(), input)
        input_elems = f(Reduce(mode="product", axis=0), input_shape)
        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
        reduce_size = f("//", input_elems, reduce_elems)
        reduce_size = f(TypeCvt(dtype=dtype), reduce_size)
        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
        channel_mean = f("/", channel_x1s, reduce_size)
        channel_var = f(
            "-", f("/", channel_x2s, reduce_size), f("*", channel_mean, channel_mean),
        )
        invsqrt_channel_var = f("**", f("+", channel_var, eps), c(-0.5))
        inv_var_wt = f("*", invsqrt_channel_var, weight)
        neg_channel_mean = f("-", channel_mean)
        outvar = f(
            "fma3", input, inv_var_wt, f("fma3", neg_channel_mean, inv_var_wt, bias),
        )
        return (outvar,), (True,)

    return batch_norm_nd
| @pytest.mark.parametrize("device", [get_default_device(), "cpux"]) | |||
| @pytest.mark.parametrize("batch_size", [1, 8]) | |||
| @pytest.mark.parametrize("channels", [3]) | |||
| @pytest.mark.parametrize( | |||
| "use_trace, symbolic", [(False, None), (True, False), (True, True)] | |||
| ) | |||
| @pytest.mark.parametrize("gopt_level", [None, 1, 2]) | |||
| @pytest.mark.parametrize("dtype", ["float32"]) | |||
| def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level, dtype): | |||
| device = CompNode(device) | |||

    def subgraph_batch_norm(inp, weight, bias, eps, diff):
        # batch norm evaluated through the compiled subgraph
        inp = inp.detach()
        with GradManager().attach(inp) as gm:
            batch_norm_fn = _get_batch_norm_fn(
                dtype, device, channels, ndim, interpret=False, gopt_level=gopt_level
            )
            out, *_ = batch_norm_fn(inp, eps, weight, bias)
            gm.backward(out * 1e3 + 1e3, diff)
        return out, inp.grad

    def primitive_batch_norm(inp, weight, bias, eps, diff):
        # reference: the same graph interpreted op by op
        inp = inp.detach()
        with GradManager().attach(inp) as gm:
            batch_norm_fn = _get_batch_norm_fn(
                dtype, device, channels, ndim, interpret=True, gopt_level=gopt_level
            )
            (out,) = batch_norm_fn(inp, eps, weight, bias)
            gm.backward(out * 1e3 + 1e3, diff)
        return out, inp.grad

    if use_trace:
        subgraph_batch_norm = trace(symbolic=symbolic)(subgraph_batch_norm)
        primitive_batch_norm = trace(symbolic=symbolic)(primitive_batch_norm)

    def rand_tensor(shape, dtype=dtype, device=device):
        return megengine.tensor(np.random.random(shape), dtype=dtype, device=device)

    # test shape change
    for image_shape in [(223, 223), (10, 20)]:
        ndim = len(image_shape) + 2
        input_shape = (batch_size, channels) + image_shape
        param_shape = (1, channels) + (1,) * len(image_shape)

        inp = rand_tensor(input_shape) * 1e3 + 1e3
        weight = rand_tensor(param_shape)
        bias = rand_tensor(param_shape)
        eps = megengine.tensor(1e-5, dtype=dtype, device=device)
        diff = rand_tensor(input_shape)

        out1, grad1 = subgraph_batch_norm(inp, weight, bias, eps, diff)
        out2, grad2 = primitive_batch_norm(inp, weight, bias, eps, diff)

        _assert_allclose(out1.numpy(), out2.numpy())
        _assert_allclose(grad1.numpy(), grad2.numpy())

@@ -15,6 +15,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine import Tensor
from megengine.autodiff.grad_manager import GradManager
from megengine.core._trace_option import use_symbolic_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm

@@ -337,3 +338,33 @@ def test_syncbn2d_no_stats():
        yv_expect = (xv - mean) / sd

        _assert_allclose(yv.numpy(), yv_expect)


def test_syncbn2d_grad():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    syncbn = SyncBatchNorm(8, track_running_stats=False)
    bn = BatchNorm2d(8, track_running_stats=False)
    for i in range(4):
        # switch both modules to eval mode halfway through
        if i == 2:
            syncbn.training = False
            bn.training = False
        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))

        with GradManager().attach(inp) as gm:
            oup = syncbn(inp)
            gm.backward(oup, diff)
        grad = inp.grad
        inp.grad = None

        with GradManager().attach(inp) as gm:
            oup_expect = bn(inp)
            gm.backward(oup_expect, diff)
        grad_expect = inp.grad
        inp.grad = None

        _assert_allclose(oup.numpy(), oup_expect.numpy())
        _assert_allclose(grad.numpy(), grad_expect.numpy())