GitOrigin-RevId: cf3bf8cb80
tags/v1.3.0
| @@ -27,9 +27,31 @@ from .utils import setscalar | |||
| _ElwMod = Elemwise.Mode | |||
| def _elwise(*args, mode): | |||
| def _elwise_apply(args, mode): | |||
| op = builtin.Elemwise(mode) | |||
| if mode in (_ElwMod.TRUE_DIV, _ElwMod.POW): | |||
| _isscalar = True | |||
| for i in args: | |||
| if isscalar(i) == False: | |||
| _isscalar = False | |||
| break | |||
| (result,) = apply(op, *args) | |||
| if _isscalar: | |||
| setscalar(result) | |||
| return result | |||
| def _elwise(*args, mode): | |||
| if mode in ( | |||
| _ElwMod.TRUE_DIV, | |||
| _ElwMod.POW, | |||
| _ElwMod.CEIL, | |||
| _ElwMod.FLOOR, | |||
| _ElwMod.ROUND, | |||
| ): | |||
| if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype( | |||
| args[0].dtype, np.integer | |||
| ): | |||
| return args[0] | |||
| args = tuple( | |||
| map( | |||
| lambda x: x.astype("float32") | |||
| @@ -39,16 +61,7 @@ def _elwise(*args, mode): | |||
| ) | |||
| ) | |||
| args = utils.convert_inputs(*args) | |||
| (result,) = apply(op, *args) | |||
| _isscalar = True | |||
| for i in args: | |||
| if isscalar(i) == False: | |||
| _isscalar = False | |||
| break | |||
| if _isscalar: | |||
| setscalar(result) | |||
| return result | |||
| return _elwise_apply(args, mode) | |||
| def _matmul(inp1, inp2): | |||
| @@ -9,10 +9,13 @@ | |||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
| import functools | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Elemwise | |||
| from ..core.tensor import megbrain_graph, utils | |||
| from ..core.tensor.array_method import _elwise_apply | |||
| from ..core.tensor.utils import isscalar, setscalar | |||
| from ..device import get_default_device | |||
| from ..jit.tracing import is_tracing | |||
| @@ -74,7 +77,6 @@ __all__ = [ | |||
| def _elwise(*args, mode): | |||
| op = builtin.Elemwise(mode) | |||
| tensor_args = list( | |||
| filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | |||
| ) | |||
| @@ -84,17 +86,33 @@ def _elwise(*args, mode): | |||
| args = utils.convert_inputs(first_arg, *args[1:]) | |||
| else: | |||
| args = utils.convert_inputs(*args) | |||
| if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): | |||
| if mode in ( | |||
| Elemwise.Mode.TRUE_DIV, | |||
| Elemwise.Mode.EXP, | |||
| Elemwise.Mode.POW, | |||
| Elemwise.Mode.LOG, | |||
| Elemwise.Mode.EXPM1, | |||
| Elemwise.Mode.LOG1P, | |||
| Elemwise.Mode.TANH, | |||
| Elemwise.Mode.ACOS, | |||
| Elemwise.Mode.ASIN, | |||
| Elemwise.Mode.ATAN2, | |||
| Elemwise.Mode.CEIL, | |||
| Elemwise.Mode.COS, | |||
| Elemwise.Mode.FLOOR, | |||
| Elemwise.Mode.H_SWISH, | |||
| Elemwise.Mode.ROUND, | |||
| Elemwise.Mode.SIGMOID, | |||
| Elemwise.Mode.SIN, | |||
| ): | |||
| if mode in ( | |||
| Elemwise.Mode.CEIL, | |||
| Elemwise.Mode.FLOOR, | |||
| Elemwise.Mode.ROUND, | |||
| ) and np.issubdtype(args[0].dtype, np.integer): | |||
| return args[0] | |||
| args = tuple(map(lambda x: x.astype("float32"), args)) | |||
| _isscalar = True | |||
| for i in args: | |||
| if isscalar(i) == False: | |||
| _isscalar = False | |||
| break | |||
| (result,) = apply(op, *args) | |||
| if _isscalar: | |||
| setscalar(result) | |||
| return result | |||
| return _elwise_apply(args, mode) | |||
| def _elemwise_multi_type(*args, mode, **kwargs): | |||
| @@ -9,6 +9,7 @@ | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| import megengine.functional.elemwise as elemwise | |||
| from megengine import tensor | |||
| from megengine.core.tensor import dtype | |||
| from megengine.functional.elemwise import _elwise | |||
| @@ -166,3 +167,20 @@ def test_qadd(): | |||
| result_mge = result_mge.astype("float32").numpy() | |||
| result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() | |||
| np.testing.assert_almost_equal(result_mge, result_expect, decimal=6) | |||
| def test_int32_input(): | |||
| x = tensor(np.array([1, 2, 3, 4, 5]), dtype="int32") | |||
| for op_name in elemwise.__all__: | |||
| op = getattr(elemwise, op_name) | |||
| nargs = op.__code__.co_argcount | |||
| if op_name == "clip": | |||
| inp = (x, 0, 1) | |||
| elif op_name.endswith("_shift"): | |||
| inp = (x, 1) | |||
| elif op_name.startswith("logical_"): | |||
| continue | |||
| else: | |||
| inp = (x,) * nargs | |||
| y = op(*inp) | |||
| y.numpy() | |||