GitOrigin-RevId: d51fad9867
tags/v1.1.0
@@ -30,7 +30,6 @@ from ..tensor.core import apply
 from ..tensor.function import Function
 from ..tensor.tensor_wrapper import TensorWrapper

-_elemwise_add_param = Elemwise(mode="add").to_c().param
 _reduce_sum_param = Reduce(mode="SUM").to_c().param[0]

@@ -44,12 +43,12 @@ def _(op: OpDef, inputs, outputs, input_requires_grad):
     if isinstance(op, OprAttr):
         grad_fn = _oprAttr_grad_fn.get(op.type, None)
         if grad_fn is None:
-            if op.type == Elemwise.name and op.param == _elemwise_add_param:
-                grad_fn = elemwise_add_grad_fn
-            elif op.type == Reduce.name and op.param[0] == _reduce_sum_param:
+            if op.type == Reduce.name and op.param[0] == _reduce_sum_param:
                 grad_fn = reduce_sum_grad_fn
             else:
                 grad_fn = default_grad_fn
+    elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD:
+        grad_fn = elemwise_add_grad_fn
     else:
         grad_fn = default_grad_fn
     return grad_fn(op, inputs, outputs, input_requires_grad)
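Note: with Elemwise now a first-class OpDef rather than a serialized OprAttr, the ADD gradient rule is selected by an isinstance/mode check instead of comparing raw param bytes. Roughly, the new dispatch behaves like this sketch (names taken from the hunk above; illustrative only):

    def pick_grad_fn(op):
        if isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD:
            return elemwise_add_grad_fn  # enum comparison, no param-byte matching
        return default_grad_fn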
@@ -158,11 +157,8 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
     params = inputs[1:]

     def make_grad(grad_op, dy):
-        grad = (
-            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            ._broadcast(TensorWrapper(input_shape))
-            .__wrapped__
-        )
+        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
+        (grad,) = apply(Broadcast(), _z, input_shape)
         (dx,) = apply(grad_op, grad, dy, *params)
         return dx
@@ -184,11 +180,8 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
     params = inputs[1:]

     def make_grad(grad_op, dy):
-        grad = (
-            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            ._broadcast(TensorWrapper(input_shape))
-            .__wrapped__
-        )
+        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
+        (grad,) = apply(Broadcast(), _z, input_shape)
         (dx,) = apply(grad_op, grad, dy, *params)
         return dx
@@ -47,7 +47,7 @@ def get_grad_managers():
 def add(a, b):
-    (c,) = apply(Elemwise(mode="add"), a, b)
+    (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b)
     return c
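The builtin op constructor now takes the Mode enum positionally; the old mode="add" keyword-string form is gone. Illustrative usage (not part of the diff):

    from megengine.core.ops.builtin import Elemwise

    op = Elemwise(Elemwise.Mode.ADD)     # new style
    assert op.mode == Elemwise.Mode.ADD  # mode is exposed via the pybind def_readwrite below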
@@ -13,7 +13,7 @@ import numpy as np
 from .._trace_option import use_symbolic_shape
 from ..ops import builtin
-from ..ops.builtin import GetVarShape
+from ..ops.builtin import Elemwise, GetVarShape
 from ..ops.special import Const
 from . import utils
 from .core import OpBase, TensorBase, TensorWrapperBase, apply

@@ -23,10 +23,12 @@ from .raw_tensor import RawTensor, as_raw_tensor
 from .tensor import Tensor
 from .utils import make_shape_tuple as _make_shape_tuple

+_ElwMod = Elemwise.Mode
+

 def _elwise(*args, mode):
-    op = builtin.Elemwise(mode=mode)
-    if mode in ("TRUE_DIV", "POW"):
+    op = builtin.Elemwise(mode)
+    if mode in (_ElwMod.TRUE_DIV, _ElwMod.POW):
         args = tuple(
             map(
                 lambda x: x.astype("float32")
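TRUE_DIV and POW still promote their inputs to float32 before applying the op (the astype in the context line above), so integer division keeps true-division semantics; only the mode test switches from strings to _ElwMod members. A hedged example of the effect:

    import megengine

    x = megengine.tensor([1, 2, 3], dtype="int32")
    y = x / 2  # routed through _elwise(..., mode=_ElwMod.TRUE_DIV);
               # inputs are cast to float32 first, so y is float-valued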
@@ -272,53 +274,53 @@ class ArrayMethodMixin(abc.ABC):
     __hash__ = None  # due to __eq__ diviates from python convention

-    __lt__ = lambda self, value: _elwise(self, value, mode="LT").astype("bool")
-    __le__ = lambda self, value: _elwise(self, value, mode="LEQ").astype("bool")
-    __gt__ = lambda self, value: _elwise(value, self, mode="LT").astype("bool")
-    __ge__ = lambda self, value: _elwise(value, self, mode="LEQ").astype("bool")
-    __eq__ = lambda self, value: _elwise(self, value, mode="EQ").astype("bool")
+    __lt__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LT).astype("bool")
+    __le__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LEQ).astype("bool")
+    __gt__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LT).astype("bool")
+    __ge__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LEQ).astype("bool")
+    __eq__ = lambda self, value: _elwise(self, value, mode=_ElwMod.EQ).astype("bool")
     __ne__ = lambda self, value: _elwise(
-        _elwise(self, value, mode="EQ").astype("bool"), mode="NOT"
+        _elwise(self, value, mode=_ElwMod.EQ).astype("bool"), mode=_ElwMod.NOT,
     )

-    __neg__ = _unary_elwise("NEGATE")
+    __neg__ = _unary_elwise(_ElwMod.NEGATE)
     __pos__ = lambda self: self
-    __abs__ = _unary_elwise("ABS")
-    __invert__ = _logical_unary_elwise("NOT")
-    __round__ = _unary_elwise("ROUND")
+    __abs__ = _unary_elwise(_ElwMod.ABS)
+    __invert__ = _logical_unary_elwise(_ElwMod.NOT)
+    __round__ = _unary_elwise(_ElwMod.ROUND)
     __trunc__ = _todo
-    __floor__ = _unary_elwise("FLOOR")
-    __ceil__ = _unary_elwise("CEIL")
+    __floor__ = _unary_elwise(_ElwMod.FLOOR)
+    __ceil__ = _unary_elwise(_ElwMod.CEIL)

-    __add__ = _binary_elwise("ADD")
-    __sub__ = _binary_elwise("SUB")
-    __mul__ = _binary_elwise("MUL")
+    __add__ = _binary_elwise(_ElwMod.ADD)
+    __sub__ = _binary_elwise(_ElwMod.SUB)
+    __mul__ = _binary_elwise(_ElwMod.MUL)
     __matmul__ = lambda self, other: _matmul(self, other)
-    __truediv__ = _binary_elwise("TRUE_DIV")
-    __floordiv__ = _binary_elwise("FLOOR_DIV")
-    __mod__ = _binary_elwise("MOD")
+    __truediv__ = _binary_elwise(_ElwMod.TRUE_DIV)
+    __floordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV)
+    __mod__ = _binary_elwise(_ElwMod.MOD)
     # __divmode__
-    __pow__ = _binary_elwise("POW")
-    __lshift__ = _binary_elwise("SHL")
-    __rshift__ = _binary_elwise("SHR")
-    __and__ = _logical_binary_elwise("AND")
-    __or__ = _logical_binary_elwise("OR")
-    __xor__ = _logical_binary_elwise("XOR")
-    __radd__ = _binary_elwise("ADD", rev=1)
-    __rsub__ = _binary_elwise("SUB", rev=1)
-    __rmul__ = _binary_elwise("MUL", rev=1)
+    __pow__ = _binary_elwise(_ElwMod.POW)
+    __lshift__ = _binary_elwise(_ElwMod.SHL)
+    __rshift__ = _binary_elwise(_ElwMod.SHR)
+    __and__ = _logical_binary_elwise(_ElwMod.AND)
+    __or__ = _logical_binary_elwise(_ElwMod.OR)
+    __xor__ = _logical_binary_elwise(_ElwMod.XOR)
+    __radd__ = _binary_elwise(_ElwMod.ADD, rev=1)
+    __rsub__ = _binary_elwise(_ElwMod.SUB, rev=1)
+    __rmul__ = _binary_elwise(_ElwMod.MUL, rev=1)
     __rmatmul__ = lambda self, other: _matmul(other, self)
-    __rtruediv__ = _binary_elwise("TRUE_DIV", rev=1)
-    __rfloordiv__ = _binary_elwise("FLOOR_DIV", rev=1)
-    __rmod__ = _binary_elwise("MOD", rev=1)
+    __rtruediv__ = _binary_elwise(_ElwMod.TRUE_DIV, rev=1)
+    __rfloordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV, rev=1)
+    __rmod__ = _binary_elwise(_ElwMod.MOD, rev=1)
     # __rdivmode__
-    __rpow__ = _binary_elwise("POW", rev=1)
-    __rlshift__ = _binary_elwise("SHL", rev=1)
-    __rrshift__ = _binary_elwise("SHR", rev=1)
-    __rand__ = _logical_binary_elwise("AND", rev=1)
-    __ror__ = _logical_binary_elwise("OR", rev=1)
-    __rxor__ = _logical_binary_elwise("XOR", rev=1)
+    __rpow__ = _binary_elwise(_ElwMod.POW, rev=1)
+    __rlshift__ = _binary_elwise(_ElwMod.SHL, rev=1)
+    __rrshift__ = _binary_elwise(_ElwMod.SHR, rev=1)
+    __rand__ = _logical_binary_elwise(_ElwMod.AND, rev=1)
+    __ror__ = _logical_binary_elwise(_ElwMod.OR, rev=1)
+    __rxor__ = _logical_binary_elwise(_ElwMod.XOR, rev=1)

     __iadd__ = _inplace(__add__)
     __isub__ = _inplace(__sub__)
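As before, __gt__/__ge__ reuse LT/LEQ with swapped operands, and every comparison casts its result to "bool"; only the mode spelling changes from strings to enum members. For instance, x > y evaluates as (sketch):

    result = _elwise(y, x, mode=_ElwMod.LT).astype("bool")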
@@ -10,6 +10,7 @@
 import functools

 from ..core.ops import builtin
+from ..core.ops.builtin import Elemwise
 from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import apply
 from ..device import get_default_device

@@ -72,7 +73,7 @@ __all__ = [

 def _elwise(*args, mode):
-    op = builtin.Elemwise(mode=mode)
+    op = builtin.Elemwise(mode)
     tensor_args = list(
         filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
     )
@@ -128,67 +129,67 @@ def add(x, y):
     [ 6.  8. 10.]]

     """
-    return _elwise(x, y, mode="add")
+    return _elwise(x, y, mode=Elemwise.Mode.ADD)


 def sub(x, y):
     """Element-wise `subtraction`."""
-    return _elwise(x, y, mode="sub")
+    return _elwise(x, y, mode=Elemwise.Mode.SUB)


 def mul(x, y):
     """Element-wise `multiplication`."""
-    return _elwise(x, y, mode="mul")
+    return _elwise(x, y, mode=Elemwise.Mode.MUL)


 def div(x, y):
     """Element-wise `(x / y)`."""
-    return _elwise(x, y, mode="true_div")
+    return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV)


 def floor_div(x, y):
     """Element-wise `floor(x / y)`."""
-    return _elwise(x, y, mode="floor_divide")
+    return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV)


 def neg(x):
     """Element-wise `negation`."""
-    return _elwise(x, mode="negate")
+    return _elwise(x, mode=Elemwise.Mode.NEGATE)


 def pow(x, y):
     """Element-wise `power`."""
-    return _elwise(x, y, mode="pow")
+    return _elwise(x, y, mode=Elemwise.Mode.POW)


 def mod(x, y):
     """Element-wise `remainder of division`."""
-    return _elwise(x, y, mode="mod")
+    return _elwise(x, y, mode=Elemwise.Mode.MOD)


 def abs(x):
     """Element-wise `absolute value`."""
-    return _elwise(x, mode="abs")
+    return _elwise(x, mode=Elemwise.Mode.ABS)


 def exp(x):
     """Element-wise `exponential`."""
-    return _elwise(x, mode="exp")
+    return _elwise(x, mode=Elemwise.Mode.EXP)


 def expm1(x):
     """Element-wise `exp(x)-1`."""
-    return _elwise(x, mode="expm1")
+    return _elwise(x, mode=Elemwise.Mode.EXPM1)


 def log(x):
     """Element-wise `logarithm (base e)`."""
-    return _elwise(x, mode="log")
+    return _elwise(x, mode=Elemwise.Mode.LOG)


 def log1p(x):
     """Element-wise `log(x+1) (base e)`."""
-    return _elwise(x, mode="log1p")
+    return _elwise(x, mode=Elemwise.Mode.LOG1P)


 def sqrt(x: Tensor) -> Tensor:

@@ -253,27 +254,27 @@ def square(x: Tensor) -> Tensor:
 def round(x):
     """Element-wise `rounding to int`."""
-    return _elwise(x, mode="round")
+    return _elwise(x, mode=Elemwise.Mode.ROUND)


 def ceil(x):
     """Element-wise `ceiling`."""
-    return _elwise(x, mode="ceil")
+    return _elwise(x, mode=Elemwise.Mode.CEIL)


 def floor(x):
     """Element-wise `floor`."""
-    return _elwise(x, mode="floor")
+    return _elwise(x, mode=Elemwise.Mode.FLOOR)


 def maximum(x, y):
     """Element-wise `maximum of array elements`."""
-    return _elwise(x, y, mode="max")
+    return _elwise(x, y, mode=Elemwise.Mode.MAX)


 def minimum(x, y):
     """Element-wise `minimum of array elements`."""
-    return _elwise(x, y, mode="min")
+    return _elwise(x, y, mode=Elemwise.Mode.MIN)


 # trigonometric functions

@@ -305,12 +306,12 @@ def cos(x):
     [-0.99   -0.6536  0.2837]]

     """
-    return _elwise(x, mode="cos")
+    return _elwise(x, mode=Elemwise.Mode.COS)


 def sin(x):
     """Element-wise `sine`."""
-    return _elwise(x, mode="sin")
+    return _elwise(x, mode=Elemwise.Mode.SIN)


 def tan(x):

@@ -320,22 +321,22 @@ def tan(x):
 def acos(x):
     """Element-wise `inverse cosine`."""
-    return _elwise(x, mode="acos")
+    return _elwise(x, mode=Elemwise.Mode.ACOS)


 def asin(x):
     """Element-wise `inverse sine`."""
-    return _elwise(x, mode="asin")
+    return _elwise(x, mode=Elemwise.Mode.ASIN)


 def atan(x):
     """Element-wise `inverse tangent`."""
-    return _elwise(x, 1, mode="atan2")
+    return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)


 def atan2(y, x):
     """Element-wise `2-argument arctangent`."""
-    return _elwise(y, x, mode="atan2")
+    return _elwise(y, x, mode=Elemwise.Mode.ATAN2)


 def cosh(x):

@@ -351,7 +352,7 @@ def sinh(x):
 def tanh(x):
     r"""Element-wise `hyperbolic tangent`."""
-    return _elwise(x, mode="tanh")
+    return _elwise(x, mode=Elemwise.Mode.TANH)


 def asinh(x):

@@ -399,12 +400,12 @@ def left_shift(x, y):
     [12 16 20]]

     """
-    return _elwise(x, y, mode="shl")
+    return _elwise(x, y, mode=Elemwise.Mode.SHL)


 def right_shift(x, y):
     """Element-wise `bitwise binary: x >> y`."""
-    return _elwise(x, y, mode="shr")
+    return _elwise(x, y, mode=Elemwise.Mode.SHR)


 # logical functions

@@ -412,22 +413,22 @@ def right_shift(x, y):
 def logical_and(x, y):
     """Element-wise `logical and: x && y`."""
-    return _elwise(x, y, mode="AND")
+    return _elwise(x, y, mode=Elemwise.Mode.AND)


 def logical_not(x):
     """Element-wise `logical not: ~x`."""
-    return _elwise(x, mode="NOT")
+    return _elwise(x, mode=Elemwise.Mode.NOT)


 def logical_or(x, y):
     """Element-wise `logical or: x || y`."""
-    return _elwise(x, y, mode="OR")
+    return _elwise(x, y, mode=Elemwise.Mode.OR)


 def logical_xor(x, y):
     """Element-wise `logical xor: x ^ y`."""
-    return _elwise(x, y, mode="XOR")
+    return _elwise(x, y, mode=Elemwise.Mode.XOR)


 # comparison functions

@@ -461,7 +462,7 @@ def equal(x, y):
     [1. 1. 1.]]

     """
-    return _elwise(x, y, mode="eq")
+    return _elwise(x, y, mode=Elemwise.Mode.EQ)


 def not_equal(x, y):

@@ -471,22 +472,22 @@ def not_equal(x, y):
 def less(x, y):
     """Element-wise `(x < y)`."""
-    return _elwise(x, y, mode="lt")
+    return _elwise(x, y, mode=Elemwise.Mode.LT)


 def less_equal(x, y):
     """Element-wise `(x <= y)`."""
-    return _elwise(x, y, mode="leq")
+    return _elwise(x, y, mode=Elemwise.Mode.LEQ)


 def greater(x, y):
     """Element-wise `(x > y)`."""
-    return _elwise(y, x, mode="lt")
+    return _elwise(y, x, mode=Elemwise.Mode.LT)


 def greater_equal(x, y):
     """Element-wise `(x >= y)`."""
-    return _elwise(y, x, mode="leq")
+    return _elwise(y, x, mode=Elemwise.Mode.LEQ)


 # other functions

@@ -515,7 +516,7 @@ def hswish(x):
     [0.     0.6667 1.6667 3.     4.    ]

     """
-    return _elwise(x, mode="h_swish")
+    return _elwise(x, mode=Elemwise.Mode.H_SWISH)


 def hsigmoid(x):

@@ -525,7 +526,7 @@ def hsigmoid(x):
 def relu(x):
     """Element-wise `max(x, 0)`."""
-    return _elwise(x, mode="relu")
+    return _elwise(x, mode=Elemwise.Mode.RELU)


 def relu6(x):

@@ -535,7 +536,7 @@ def relu6(x):
 def sigmoid(x):
     """Element-wise `1 / ( 1 + exp( -x ) )`."""
-    return _elwise(x, mode="sigmoid")
+    return _elwise(x, mode=Elemwise.Mode.SIGMOID)


 def clip(x: Tensor, lower=None, upper=None) -> Tensor:
@@ -12,6 +12,7 @@ from typing import Optional, Sequence, Tuple, Union
 from ..core._imperative_rt import CompNode
 from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
+from ..core.ops.builtin import BatchNorm
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply

@@ -643,19 +644,22 @@ def batch_norm(
     if inp.ndim != 4:
         raise NotImplementedError("batch_norm for ndim != 4")

-    def full_value(value):
-        C = inp.shape[1]
-        (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
-        return broadcast_to(x, [1, C, 1, 1])
-
-    def expand_or_full(x, value):
-        if x is None:
-            return full_value(value)
-        return expand_dims(x, [0, 2, 3])
+    C = inp.shape[1]

     def make_full_if_none(x, value):
         if x is None:
-            return full(shape=(1, inp.shape[1], 1, 1), value=value)
+            (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
+            shape = utils.astensor1d(
+                (1, C, 1, 1), inp, dtype="int32", device=inp.device
+            )
+            (result,) = apply(builtin.Broadcast(), x, shape)
+            return result
+        elif x.ndim == 1:
+            shape = utils.astensor1d(
+                (1, C, 1, 1), inp, dtype="int32", device=inp.device
+            )
+            (result,) = apply(builtin.Reshape(), x, shape)
+            return result
         return x

     has_mean = running_mean is not None

@@ -674,19 +678,25 @@ def batch_norm(
         inp, weight, bias, running_mean, running_var
     )

-    weight = expand_or_full(weight, 1)
-    bias = expand_or_full(bias, 0)
+    weight = make_full_if_none(weight, 1)
+    bias = make_full_if_none(bias, 0)

     if not training:
-        op = builtin.BatchNorm(fwd_mode="INFERENCE", epsilon=eps, param_dim="DIM_1C11")
+        op = builtin.BatchNorm(
+            BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0
+        )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret

     else:
         op = builtin.BatchNorm(
-            avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11"
+            BatchNorm.ParamDim.DIM_1C11,
+            BatchNorm.FwdMode.TRAINING,
+            eps,
+            1.0 - momentum,
+            1.0,
+            0.0,
         )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
             running_var = make_full_if_none(running_var, 1)

@@ -708,7 +718,7 @@ def batch_norm(
         else:
             return inp, new_mean, new_var
     else:
-        _, _, inp, = apply(op, inp, weight, bias)
+        (_, _, inp,) = apply(op, inp, weight, bias)
         return inp
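builtin.BatchNorm is now constructed positionally, matching the pybind signature added below: (param_dim, fwd_mode, epsilon, avg_factor, scale, bias). A hedged example of the two call styles used above (values illustrative):

    from megengine.core.ops import builtin
    from megengine.core.ops.builtin import BatchNorm

    eps, momentum = 1e-5, 0.9
    op_inference = builtin.BatchNorm(
        BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0
    )
    op_training = builtin.BatchNorm(
        BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.TRAINING, eps, 1.0 - momentum, 1.0, 0.0
    )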
@@ -72,14 +72,15 @@ class _BatchNorm(Module):
                 self.track_running_stats == False
             ), "track_running_stats can not be initilized to False and changed to True later"

-        _ndims = len(inp.shape)
+        inp_shape = inp.shape
+        _ndims = len(inp_shape)
         if _ndims != 4:
-            origin_shape = inp.shape
+            origin_shape = inp_shape
             if _ndims == 2:
-                n, c = inp.shape[0], inp.shape[1]
+                n, c = inp_shape[0], inp_shape[1]
                 new_shape = (n, c, 1, 1)
             elif _ndims == 3:
-                n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
+                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
                 new_shape = (n, c, h, 1)

             inp = inp.reshape(new_shape)
@@ -150,17 +151,18 @@ class SyncBatchNorm(_BatchNorm):
     def forward(self, inp):
         self._check_input_ndim(inp)

-        _ndims = len(inp.shape)
+        inp_shape = inp.shape
+        _ndims = len(inp_shape)
         if _ndims != 4:
             new_shape = Tensor([1, 1, 1, 1], device=inp.device)
-            origin_shape = inp.shape
+            origin_shape = inp_shape
             if _ndims == 2:
                 new_shape[:2] = origin_shape[:2]
             elif _ndims == 3:
                 new_shape[:3] = origin_shape[:3]
             else:
                 raise ValueError(
-                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
+                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                 )
             inp = inp.reshape(new_shape)
@@ -19,6 +19,8 @@
 #include "megbrain/imperative/ops/io_remote.h"
 #include "megbrain/imperative/ops/cond_take.h"
 #include "megbrain/imperative/ops/nms.h"
+#include "megbrain/imperative/ops/elemwise.h"
+#include "megbrain/imperative/ops/batch_norm.h"

 namespace py = pybind11;

@@ -117,4 +119,91 @@ void init_ops(py::module m) {
         .def(py::init<float, uint32_t>())
         .def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
         .def_readwrite("max_output", &NMSKeep::max_output);
+
+    py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> elemwise(m, "Elemwise");
+    elemwise.def(py::init<Elemwise::Mode>())
+        .def_readwrite("mode", &Elemwise::mode);
+
+#define V(m) .value(#m, Elemwise::Mode::m)
+    py::enum_<Elemwise::Mode>(elemwise, "Mode")
+        V(RELU)
+        V(ABS)
+        V(ACOS)
+        V(ASIN)
+        V(CEIL)
+        V(COS)
+        V(EXP)
+        V(EXPM1)
+        V(FLOOR)
+        V(LOG)
+        V(LOG1P)
+        V(NEGATE)
+        V(SIGMOID)
+        V(SIN)
+        V(TANH)
+        V(ABS_GRAD)
+        V(ADD)
+        V(FLOOR_DIV)
+        V(MAX)
+        V(MIN)
+        V(MOD)
+        V(MUL)
+        V(POW)
+        V(SIGMOID_GRAD)
+        V(SUB)
+        V(SWITCH_GT0)
+        V(TANH_GRAD)
+        V(TRUE_DIV)
+        V(LOG_SUM_EXP)
+        V(LT)
+        V(LEQ)
+        V(EQ)
+        V(SHL)
+        V(SHR)
+        V(COND_LEQ_MOV)
+        V(FUSE_MUL_ADD3)
+        V(FUSE_MUL_ADD4)
+        V(FUSE_ADD_RELU)
+        V(FUSE_ADD_SIGMOID)
+        V(FUSE_ADD_TANH)
+        V(FAST_TANH)
+        V(FAST_TANH_GRAD)
+        V(ROUND)
+        V(RMULH)
+        V(ATAN2)
+        V(ERF)
+        V(ERFINV)
+        V(ERFC)
+        V(ERFCINV)
+        V(H_SWISH)
+        V(H_SWISH_GRAD)
+        V(FUSE_ADD_H_SWISH)
+        V(NOT)
+        V(AND)
+        V(OR)
+        V(XOR);
+#undef V
+
+    py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> batchnorm(m, "BatchNorm");
+    batchnorm.def(py::init<const BatchNorm::Param::ParamDim&, const BatchNorm::Param::FwdMode&, double, double, float, float>())
+        .def_readwrite("param_dim", &BatchNorm::param_dim)
+        .def_readwrite("fwd_mode", &BatchNorm::fwd_mode)
+        .def_readwrite("epsilon", &BatchNorm::epsilon)
+        .def_readwrite("avg_factor", &BatchNorm::avg_factor)
+        .def_readwrite("scale", &BatchNorm::scale)
+        .def_readwrite("bias", &BatchNorm::bias);
+
+#define V(m) .value(#m, BatchNorm::Param::ParamDim::m)
+    py::enum_<BatchNorm::Param::ParamDim>(batchnorm, "ParamDim")
+        V(DIM_11HW)
+        V(DIM_1CHW)
+        V(DIM_1C11);
+#undef V
+
+#define V(m) .value(#m, BatchNorm::Param::FwdMode::m)
+    py::enum_<BatchNorm::Param::FwdMode>(batchnorm, "FwdMode")
+        V(TRAINING)
+        V(INFERENCE);
+#undef V
 }
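The #define V(m) macro stamps out one .value(#m, ...) registration per enum member, so Python sees ordinary nested enums on the op classes. A quick sanity check from the Python side (illustrative):

    from megengine.core.ops.builtin import Elemwise, BatchNorm

    assert Elemwise(Elemwise.Mode.ADD) == Elemwise(Elemwise.Mode.ADD)
    print(BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.TRAINING)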
@@ -27,7 +27,7 @@ from megengine.functional.distributed import remote_recv, remote_send

 def _elwise(mode):
-    op = Elemwise(mode=mode)
+    op = Elemwise(mode)

     def f(*args):
         (result,) = apply(op, *args)

@@ -36,10 +36,10 @@ def _elwise(mode):
     return f


-add = _elwise("add")
-mul = _elwise("mul")
-cos = _elwise("cos")
-relu = _elwise("relu")
+add = _elwise(Elemwise.Mode.ADD)
+mul = _elwise(Elemwise.Mode.MUL)
+cos = _elwise(Elemwise.Mode.COS)
+relu = _elwise(Elemwise.Mode.RELU)


 def as_tensor(x):
@@ -255,7 +255,7 @@ def test_elemwise_relu():

 def test_elemwise_relu_backward_fn():
-    op = Elemwise(mode="relu").to_c()
+    op = Elemwise(Elemwise.Mode.RELU)
     attr = TensorAttr()
     attr.dtype = "float32"
     attr.comp_node = "xpux"
@@ -17,7 +17,7 @@ def elemwise(*args, mode):
     from megengine.core.ops.builtin import Elemwise
     from megengine.core._imperative_rt.imperative import apply_op

-    return apply_op(Elemwise(mode=mode).to_c(), args)
+    return apply_op(Elemwise(mode), args)


 def test_basic_interface():

@@ -37,13 +37,15 @@ def test_basic_interface():
 def test_opr_attr():
     from megengine.core.ops.builtin import Elemwise

-    assert Elemwise(mode="add") == Elemwise(mode="add")
+    assert Elemwise(Elemwise.Mode.ADD) == Elemwise(Elemwise.Mode.ADD)


 def test_simple_arith():
+    from megengine.core.ops.builtin import Elemwise
+
     x = np.random.rand(10).astype("float32")
     xx = megengine.core._imperative_rt.put(x)
-    (yy,) = elemwise(xx, xx, mode="mul")
+    (yy,) = elemwise(xx, xx, mode=Elemwise.Mode.MUL)
     np.testing.assert_allclose(x * x, megengine.core._imperative_rt.get_value(yy))
     megengine.core._imperative_rt.delete(xx)
     megengine.core._imperative_rt.delete(yy)

@@ -64,7 +66,7 @@ def test_raw_tensor():
     x = np.random.rand(10).astype("float32")
     xx = as_raw_tensor(x)
-    (yy,) = apply(Elemwise(mode="mul"), xx, xx)
+    (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
     np.testing.assert_allclose(x * x, yy.numpy())
-    (yy,) = apply(Elemwise(mode="mul"), xx, xx)
+    (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
     np.testing.assert_allclose(x * x, yy.numpy())
@@ -17,6 +17,7 @@ import megengine.functional as F
 from megengine import cgtools, tensor
 from megengine.core._trace_option import set_symbolic_shape
 from megengine.core.ops import builtin as ops
+from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.core import apply
 from megengine.core.tensor.raw_tensor import as_raw_tensor
 from megengine.functional import exp, log

@@ -28,7 +29,7 @@ def test_trace():
     @trace(symbolic=symbolic)
     def f(x):
-        op = ops.Elemwise(mode="negate")
+        op = ops.Elemwise(Elemwise.Mode.NEGATE)
         (y,) = apply(op, x)
         return y

@@ -44,7 +45,7 @@ def test_exclude_from_trace():
     @trace(symbolic=symbolic)
     def f(x):
-        neg = ops.Elemwise(mode="negate")
+        neg = ops.Elemwise(Elemwise.Mode.NEGATE)
         (x,) = apply(neg, x)
         with exclude_from_trace():
             if i % 2:

@@ -65,7 +66,7 @@ def test_print_in_trace():
     @trace(symbolic=symbolic)
     def f(x):
         nonlocal buf
-        neg = ops.Elemwise(mode="negate")
+        neg = ops.Elemwise(Elemwise.Mode.NEGATE)
         (x,) = apply(neg, x)
         buf = x.numpy()
         (x,) = apply(neg, x)

@@ -85,7 +86,7 @@ def test_print_in_trace():
 def test_dump():
     @trace(symbolic=True, capture_as_const=True)
     def f(a, b):
-        op = ops.Elemwise(mode="add")
+        op = ops.Elemwise(Elemwise.Mode.ADD)
         (y,) = apply(op, a, b)
         return y

@@ -111,7 +112,7 @@ def test_capture_dump():
     @trace(symbolic=True, capture_as_const=True)
     def f(x):
-        op = ops.Elemwise(mode="mul")
+        op = ops.Elemwise(Elemwise.Mode.MUL)
         (y,) = apply(op, x, a)
         return y

@@ -133,7 +134,7 @@ def test_dump_volatile():
     @trace(symbolic=True, capture_as_const=True)
     def f(x):
-        op = ops.Elemwise(mode="mul")
+        op = ops.Elemwise(Elemwise.Mode.MUL)
         (y,) = apply(op, x, p)
         return y

@@ -159,7 +160,7 @@ def test_trace_profiler():
     @trace(symbolic=symbolic, profiling=True)
     def f(x):
-        op = ops.Elemwise(mode="negate")
+        op = ops.Elemwise(Elemwise.Mode.NEGATE)
         (y,) = apply(op, x)
         return y
@@ -0,0 +1,84 @@
+/**
+ * \file imperative/src/impl/ops/batch_norm.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/batch_norm.h"
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node = &node_->cast_final_safe<opr::BatchNorm>();
+    auto&& param = node->param();
+    return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon,
+                           param.avg_factor, param.scale, param.bias);
+}
+
+cg::OperatorNodeBase* apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& bn_opr = def.cast_final_safe<BatchNorm>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 3 || nr_inp == 5,
+               "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
+    if (nr_inp == 3) {
+        return opr::BatchNorm::make(
+                inputs[0], inputs[1], inputs[2],
+                {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
+                .node()->owner_opr();
+    } else {
+        return opr::BatchNorm::make(
+                inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
+                {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
+                .node()->owner_opr();
+    }
+}
+
+SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op_def = def.cast_final_safe<BatchNorm>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 3 || nr_inp == 5,
+               "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
+    // need running mean/variance
+    bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING;
+    size_t nr_out = need_stat ? 5 : 3;
+    SmallVector<LogicalTensorDesc> out_shapes(nr_out);
+    auto&& i0 = inputs[0];
+    auto&& i1 = inputs[1];
+    size_t i = 0;
+    if (!need_stat) {
+        out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node};
+        i = 2;
+    }
+    for (; i < nr_out - 1; ++i) {
+        out_shapes[i] = {i1.layout, i1.comp_node};
+    }
+    out_shapes[nr_out - 1] = {i0.layout, i0.comp_node};
+    return out_shapes;
+}
+
+OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
+    .make_from_op_node(make_from_op_node)
+    .apply_on_var_node(apply_on_var_node)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .fallback();
+
+} // anonymous namespace
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm);
+
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
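infer_output_attrs_fallible mirrors the output arity exercised by the batch_norm hunk above: five outputs when running statistics are supplied in TRAINING mode, otherwise three, with the first two descriptors set to zero-length placeholder layouts. A rough sketch of the Python-visible contract (assuming ops built as in that hunk; illustration only):

    outs = apply(op_training, inp, weight, bias, running_mean, running_var)
    assert len(outs) == 5  # result is outs[-1]
    outs = apply(op_training, inp, weight, bias)
    assert len(outs) == 3  # (_, _, result)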
@@ -0,0 +1,78 @@
+/**
+ * \file imperative/src/impl/ops/elemwise.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/elemwise.h"
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node = &node_->cast_final_safe<opr::Elemwise>();
+    return Elemwise::make(node->param().mode);
+}
+
+cg::OperatorNodeBase* apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
+    return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
+}
+
+SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op_def = def.cast_final_safe<Elemwise>();
+    auto trait = Elemwise::ModeTrait::from_mode(op_def.mode);
+    mgb_assert(inputs.size() == trait.arity,
+               "%s expects %u inputs; got %zu actually", trait.name,
+               trait.arity, inputs.size());
+    TensorShapeArray inp_shapes;
+    DType out_dt;
+    CompNode out_cn;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& t = inputs[i];
+        if (!i) {
+            out_cn = t.comp_node;
+            out_dt = t.layout.dtype;
+        } else {
+            mgb_assert(t.comp_node == out_cn);
+            mgb_assert(t.layout.dtype == out_dt);
+        }
+        if (t.layout.ndim > 0) {
+            inp_shapes.push_back(t.layout);
+        } else {
+            TensorLayout out_layout;
+            out_layout.ndim = 0;
+            out_layout.dtype = out_dt;
+            return {{out_layout, out_cn}};
+        }
+    }
+    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
+    return {{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}};
+}
+
+OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
+    .make_from_op_node(make_from_op_node)
+    .apply_on_var_node(apply_on_var_node)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .fallback();
+
+} // anonymous namespace
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise);
+
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,70 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/ops/batch_norm.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/opr/dnn/batch_norm.h"
+#include "megbrain/imperative/op_def.h"
+#include "megbrain/utils/hash.h"
+
+namespace mgb::imperative {
+
+class BatchNorm : public OpDefImplBase<BatchNorm> {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+
+public:
+    using Param = opr::BatchNorm::Param;
+
+    Param::ParamDim param_dim;
+    Param::FwdMode fwd_mode;
+    double epsilon;
+    double avg_factor;
+    float scale;
+    float bias;
+
+    BatchNorm() = default;
+
+    BatchNorm(const Param::ParamDim& param_dim_, const Param::FwdMode& fwd_mode_,
+              double epsilon_, double avg_factor_, float scale_, float bias_)
+        : param_dim(param_dim_),
+          fwd_mode(fwd_mode_),
+          epsilon(epsilon_),
+          avg_factor(avg_factor_),
+          scale(scale_),
+          bias(bias_) {}
+
+    size_t hash() const override {
+        XXHash xxhash{};
+        auto append = [&xxhash](auto field) {
+            auto hash_val = HashTrait<decltype(field)>::eval(field);
+            xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
+        };
+        append(param_dim);
+        append(fwd_mode);
+        append(epsilon);
+        append(avg_factor);
+        append(scale);
+        append(bias);
+        return xxhash.digest();
+    }
+
+    bool is_same_st(const Hashable& rhs_) const override {
+        auto&& rhs = static_cast<const BatchNorm&>(rhs_);
+        return rhs.param_dim == param_dim
+            && rhs.fwd_mode == fwd_mode
+            && rhs.epsilon == epsilon
+            && rhs.avg_factor == avg_factor
+            && rhs.scale == scale
+            && rhs.bias == bias;
+    }
+};
+
+} // namespace mgb::imperative
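hash() reduces each field through HashTrait and folds the per-field values into one running XXHash; is_same_st is the matching structural equality, so BatchNorm ops with identical parameters compare and dedupe as equal. A rough Python analogue of the fold-then-digest pattern (stand-in hash, illustration only):

    import hashlib, struct

    def bn_hash(param_dim, fwd_mode, epsilon, avg_factor, scale, bias):
        h = hashlib.blake2b()  # stand-in for XXHash
        for value, fmt in [(param_dim, "<q"), (fwd_mode, "<q"),
                           (epsilon, "<d"), (avg_factor, "<d"),
                           (scale, "<f"), (bias, "<f")]:
            h.update(struct.pack(fmt, value))  # one fixed-width field per update
        return h.digest()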
@@ -0,0 +1,42 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/ops/elemwise.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/imperative/op_def.h"
+
+namespace mgb::imperative {
+
+class Elemwise : public OpDefImplBase<Elemwise> {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+
+public:
+    using Mode = opr::Elemwise::Mode;
+    using ModeTrait = megdnn::Elemwise::ModeTrait;
+
+    Mode mode;
+
+    Elemwise() = default;
+
+    Elemwise(const Mode& mode_): mode(mode_) {}
+
+    size_t hash() const override {
+        return hash_pair_combine(mgb::hash(mode), reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
+    }
+
+    bool is_same_st(const Hashable& rhs_) const override {
+        auto&& rhs = static_cast<const Elemwise&>(rhs_);
+        return rhs.mode == mode;
+    }
+};
+
+} // namespace mgb::imperative