GitOrigin-RevId: 57be5cec7b
tags/v1.1.0
| @@ -72,7 +72,27 @@ __all__ = [ | |||||
| ] | ] | ||||
| class _ElemwiseMode(Elemwise.Mode): | |||||
| @classmethod | |||||
| def __normalize(cls, val): | |||||
| if isinstance(val, str): | |||||
| if not hasattr(cls, "__member_upper_dict__"): | |||||
| cls.__member_upper_dict__ = { | |||||
| k.upper(): v for k, v in cls.__members__.items() | |||||
| } | |||||
| val = cls.__member_upper_dict__.get(val.upper(), val) | |||||
| return val | |||||
| @classmethod | |||||
| def convert(cls, val): | |||||
| val = cls.__normalize(val) | |||||
| if isinstance(val, cls): | |||||
| return val | |||||
| return cls(val) | |||||
| def _elwise(*args, mode): | def _elwise(*args, mode): | ||||
| mode = _ElemwiseMode.convert(mode) | |||||
| op = builtin.Elemwise(mode) | op = builtin.Elemwise(mode) | ||||
| tensor_args = list( | tensor_args = list( | ||||
| filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | ||||
| @@ -73,11 +73,9 @@ class Elemwise(Module): | |||||
| * "NOT": bool unary: ~x | * "NOT": bool unary: ~x | ||||
| """ | """ | ||||
| _elemwise_mode_type = P.Elemwise.Mode | |||||
| def __init__(self, method): | def __init__(self, method): | ||||
| super().__init__() | super().__init__() | ||||
| self.method = self._elemwise_mode_type.convert(method) | |||||
| self.method = method | |||||
| def forward(self, *inps): | def forward(self, *inps): | ||||
| return _elwise(*inps, mode=self.method) | return _elwise(*inps, mode=self.method) | ||||
| @@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule): | |||||
| Return a :class:`~.QATModule` instance converted from | Return a :class:`~.QATModule` instance converted from | ||||
| a float :class:`~.Module` instance. | a float :class:`~.Module` instance. | ||||
| """ | """ | ||||
| return cls(float_module.method.name) | |||||
| return cls(float_module.method) | |||||
| @@ -33,4 +33,4 @@ class Elemwise(QuantizedModule): | |||||
| Return a :class:`~.QuantizedModule` instance converted from a | Return a :class:`~.QuantizedModule` instance converted from a | ||||
| :class:`~.QATModule` instance. | :class:`~.QATModule` instance. | ||||
| """ | """ | ||||
| return cls(qat_module.method.name, qat_module.get_activation_dtype()) | |||||
| return cls(qat_module.method, qat_module.get_activation_dtype()) | |||||
| @@ -10,6 +10,7 @@ import numpy as np | |||||
| import megengine.functional as F | import megengine.functional as F | ||||
| from megengine import tensor | from megengine import tensor | ||||
| from megengine.functional.elemwise import _elwise | |||||
| def test_abs(): | def test_abs(): | ||||
| @@ -21,6 +22,17 @@ def test_abs(): | |||||
| np.testing.assert_allclose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0))) | np.testing.assert_allclose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0))) | ||||
| def test_elemwise_mode_string(): | |||||
| np.testing.assert_allclose( | |||||
| _elwise(tensor([-3.0, -4.0, -5.0]), mode="ABS").numpy(), | |||||
| np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)), | |||||
| ) | |||||
| np.testing.assert_allclose( | |||||
| _elwise(-3.0, mode="ABS").numpy(), np.abs(np.float32(-3.0)) | |||||
| ) | |||||
| def test_multiply(): | def test_multiply(): | ||||
| np.testing.assert_allclose( | np.testing.assert_allclose( | ||||
| F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0)) | F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0)) | ||||
| @@ -0,0 +1,30 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| import megengine.functional as F | |||||
| from megengine import tensor | |||||
| from megengine.module import Elemwise | |||||
| def test_module_elemwise(): | |||||
| def test_func(method, *inps): | |||||
| elemwise = Elemwise(method) | |||||
| outputs = elemwise(*inps) | |||||
| return outputs.numpy() | |||||
| x = np.random.rand(100).astype("float32") | |||||
| y = np.random.rand(100).astype("float32") | |||||
| x, y = tensor(x), tensor(y) | |||||
| np.testing.assert_almost_equal( | |||||
| test_func("H_SWISH", x), F.hswish(x).numpy(), decimal=6 | |||||
| ) | |||||
| np.testing.assert_almost_equal( | |||||
| test_func("ADD", x, y), F.add(x, y).numpy(), decimal=6 | |||||
| ) | |||||