|
|
|
@@ -10,8 +10,9 @@ |
|
|
|
import functools |
|
|
|
|
|
|
|
from ..core.ops import builtin |
|
|
|
from ..core.tensor import utils |
|
|
|
from ..core.tensor import megbrain_graph, utils |
|
|
|
from ..core.tensor.core import apply |
|
|
|
from ..device import get_default_device |
|
|
|
from ..tensor import Tensor |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
@@ -76,11 +77,17 @@ __all__ = [ |
|
|
|
|
|
|
|
def _elwise(*args, mode): |
|
|
|
op = builtin.Elemwise(mode=mode) |
|
|
|
tensor_args = list( |
|
|
|
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) |
|
|
|
) |
|
|
|
if len(tensor_args) == 0: |
|
|
|
dtype = utils.dtype_promotion(args) |
|
|
|
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) |
|
|
|
args = utils.convert_inputs(first_arg, *args[1:]) |
|
|
|
else: |
|
|
|
args = utils.convert_inputs(*args) |
|
|
|
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): |
|
|
|
args = tuple( |
|
|
|
map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args) |
|
|
|
) |
|
|
|
args = utils.convert_inputs(*args) |
|
|
|
args = tuple(map(lambda x: x.astype("float32"), args)) |
|
|
|
(result,) = apply(op, *args) |
|
|
|
return result |
|
|
|
|
|
|
|
|