| @@ -21,6 +21,7 @@ from ..distributed import WORLD, is_distributed | |||||
| from ..jit.tracing import is_tracing | from ..jit.tracing import is_tracing | ||||
| from ..random import uniform | from ..random import uniform | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from ..utils.tuple_function import _pair, _pair_nonzero | |||||
| from .debug_param import get_conv_execution_strategy | from .debug_param import get_conv_execution_strategy | ||||
| from .distributed import all_reduce_sum | from .distributed import all_reduce_sum | ||||
| from .elemwise import exp, floor, log, log1p, maximum, minimum, relu | from .elemwise import exp, floor, log, log1p, maximum, minimum, relu | ||||
| @@ -35,7 +36,6 @@ from .tensor import ( | |||||
| squeeze, | squeeze, | ||||
| zeros, | zeros, | ||||
| ) | ) | ||||
| from .types import _pair, _pair_nonzero | |||||
| __all__ = [ | __all__ = [ | ||||
| "adaptive_avg_pool2d", | "adaptive_avg_pool2d", | ||||
| @@ -11,8 +11,8 @@ from typing import Tuple, Union | |||||
| from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from ..utils.tuple_function import _pair, _pair_nonzero | |||||
| from .debug_param import get_conv_execution_strategy | from .debug_param import get_conv_execution_strategy | ||||
| from .types import _pair, _pair_nonzero | |||||
| def conv_bias_activation( | def conv_bias_activation( | ||||
| @@ -1,41 +0,0 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import functools | |||||
| def get_ndtuple(value, *, n, allow_zero: bool = True): | |||||
| r""" | |||||
| Converts possibly 1D tuple to n-dim tuple. | |||||
| :param value: value will be filled in generated tuple. | |||||
| :param n: how many elements will the tuple have. | |||||
| :param allow_zero: whether to allow zero tuple value. | |||||
| :return: a tuple. | |||||
| """ | |||||
| if not isinstance(value, collections.Iterable): | |||||
| value = int(value) | |||||
| value = tuple([value for i in range(n)]) | |||||
| else: | |||||
| assert len(value) == n, "tuple len is not equal to n: {}".format(value) | |||||
| spatial_axis = map(int, value) | |||||
| value = tuple(spatial_axis) | |||||
| if allow_zero: | |||||
| minv = 0 | |||||
| else: | |||||
| minv = 1 | |||||
| assert min(value) >= minv, "invalid value: {}".format(value) | |||||
| return value | |||||
| _single = functools.partial(get_ndtuple, n=1, allow_zero=True) | |||||
| _pair = functools.partial(get_ndtuple, n=2, allow_zero=True) | |||||
| _pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False) | |||||
| _triple = functools.partial(get_ndtuple, n=3, allow_zero=True) | |||||
| _quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True) | |||||
| @@ -11,8 +11,8 @@ from typing import Tuple, Union | |||||
| import numpy as np | import numpy as np | ||||
| from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu | from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu | ||||
| from ..functional.types import _pair, _pair_nonzero | |||||
| from ..tensor import Parameter | from ..tensor import Parameter | ||||
| from ..utils.tuple_function import _pair, _pair_nonzero | |||||
| from . import init | from . import init | ||||
| from .module import Module | from .module import Module | ||||